Commit a72a7a81 authored by lucas_miranda's avatar lucas_miranda
Browse files

Modified cluster purity computation. Instead of KNN, we now look at...

Modified cluster purity computation. Instead of KNN, we now look at neighborhoods of a predefined radius
parent 3bd7d6ba
......@@ -216,7 +216,7 @@ class neighbor_cluster_purity(tf.keras.callbacks.Callback):
super().__init__()
self.variational = variational
self.validation_data = validation_data
self.r = r # Make radius default depend on encoding dimensions
self.r = r # Make radius default depend on encoding dimensions
self.samples = samples
self.log_dir = log_dir
......
......@@ -252,7 +252,7 @@ class SEQ_2_SEQ_GMVAE:
montecarlo_kl: int = 1,
neuron_control: bool = False,
number_of_components: int = 1,
overlap_loss: float = -1.0,
overlap_loss: float = 0.0,
phenotype_prediction: float = 0.0,
predictor: float = 0.0,
reg_cat_clusters: bool = False,
......@@ -599,7 +599,7 @@ class SEQ_2_SEQ_GMVAE:
z_gauss = deepof.model_utils.Cluster_overlap(
self.ENCODING,
self.number_of_components,
loss=tf.maximum(0.0, self.overlap_loss).numpy(),
loss=self.overlap_loss,
)(z_gauss)
z = tfpl.DistributionLambda(
......
......@@ -270,15 +270,14 @@ def test_find_learning_rate():
def test_neighbor_cluster_purity():
X = np.random.uniform(0, 10, [1500, 5])
y = np.random.randint(0, 2, [1500, 1])
X = np.random.uniform(0, 10, [1500, 5, 6])
test_model = deepof.models.SEQ_2_SEQ_GMVAE()
test_model.build(X.shape)
gmvaep = test_model.build(X.shape)[3]
test_model.fit(
gmvaep.fit(
X,
X,
y,
callbacks=deepof.model_utils.neighbor_cluster_purity(
validation_data=X, variational=True
),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment