From dd62bd2e44fc3c34fc3ded04bef4cc414426305d Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 12 Mar 2021 04:08:04 +0100
Subject: [PATCH] Implemented KNN_purity callback

---
 deepof/model_utils.py     | 7 +++++--
 deepof/models.py          | 4 ++--
 deepof/train_utils.py     | 1 +
 tests/test_train_utils.py | 4 ++++
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index ad94d9db..209eb2d4 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -209,8 +209,11 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
 
     """
 
-    def __init__(self, validation_data=None, k=100, samples=10000, log_dir="."):
+    def __init__(
+        self, variational=True, validation_data=None, k=100, samples=10000, log_dir="."
+    ):
         super().__init__()
+        self.variational = variational
         self.validation_data = validation_data
         self.k = k
         self.samples = samples
@@ -220,7 +223,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
     def on_epoch_end(self, epoch, logs=None):
         """ Passes samples through the encoder and computes cluster purity on the latent embedding """
 
-        if self.validation_data is not None:
+        if self.validation_data is not None and self.variational:
 
             # Get encoer and grouper from full model
             cluster_means = [
diff --git a/deepof/models.py b/deepof/models.py
index cd4b2b42..d34f909d 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -253,7 +253,7 @@ class SEQ_2_SEQ_GMVAE:
         montecarlo_kl: int = 1,
         neuron_control: bool = False,
         number_of_components: int = 1,
-        overlap_loss: float = -1.,
+        overlap_loss: float = -1.0,
         phenotype_prediction: float = 0.0,
         predictor: float = 0.0,
         reg_cat_clusters: bool = False,
@@ -606,7 +606,7 @@ class SEQ_2_SEQ_GMVAE:
             z_gauss = deepof.model_utils.Cluster_overlap(
                 self.ENCODING,
                 self.number_of_components,
-                loss=tf.maximum(0., self.overlap_loss).numpy(),
+                loss=tf.maximum(0.0, self.overlap_loss).numpy(),
             )(z_gauss)
 
         z = tfpl.DistributionLambda(
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index f9e09bdf..77da7c98 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -117,6 +117,7 @@ def get_callbacks(
         samples=knn_samples,
         validation_data=X_val,
         log_dir=os.path.join(outpath, "metrics"),
+        variational=variational,
     )
 
     onecycle = deepof.model_utils.one_cycle_scheduler(
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index b35a1978..f4899f71 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -117,6 +117,8 @@ def test_autoencoder_fitting(
         phenotype_class=pheno_class,
         predictor=predictor,
         variational=variational,
+        knn_neighbors=10,
+        knn_samples=10,
     )
 
 
@@ -168,6 +170,8 @@ def test_tune_search(
             True,
             True,
             None,
+            knn_neighbors=10,
+            knn_samples=10,
         )
     )[1:]
 
-- 
GitLab