diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index ad94d9dbbe3089480154ddbc681d4cae87f4b2f9..209eb2d4cf3fa738d79edbba7d517e7f1bc86b30 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -209,8 +209,11 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
 
     """
 
-    def __init__(self, validation_data=None, k=100, samples=10000, log_dir="."):
+    def __init__(
+        self, variational=True, validation_data=None, k=100, samples=10000, log_dir="."
+    ):
         super().__init__()
+        self.variational = variational
         self.validation_data = validation_data
         self.k = k
         self.samples = samples
@@ -220,7 +223,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
     def on_epoch_end(self, epoch, logs=None):
         """ Passes samples through the encoder and computes cluster purity on the latent embedding """
 
-        if self.validation_data is not None:
+        if self.validation_data is not None and self.variational:
 
             # Get encoer and grouper from full model
             cluster_means = [
diff --git a/deepof/models.py b/deepof/models.py
index cd4b2b425a9c6ae9f0221e3e8b8c4878c612a9b2..d34f909daf36f9751e4879fbf300c2e14e2f0314 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -253,7 +253,7 @@ class SEQ_2_SEQ_GMVAE:
         montecarlo_kl: int = 1,
         neuron_control: bool = False,
         number_of_components: int = 1,
-        overlap_loss: float = -1.,
+        overlap_loss: float = -1.0,
         phenotype_prediction: float = 0.0,
         predictor: float = 0.0,
         reg_cat_clusters: bool = False,
@@ -606,7 +606,7 @@ class SEQ_2_SEQ_GMVAE:
             z_gauss = deepof.model_utils.Cluster_overlap(
                 self.ENCODING,
                 self.number_of_components,
-                loss=tf.maximum(0., self.overlap_loss).numpy(),
+                loss=tf.maximum(0.0, self.overlap_loss).numpy(),
             )(z_gauss)
 
         z = tfpl.DistributionLambda(
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index f9e09bdfd3b4f345f65f34db4fe8d0c015c50ab3..77da7c9818f01635b3a5add44830262df08973df 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -117,6 +117,7 @@ def get_callbacks(
         samples=knn_samples,
         validation_data=X_val,
         log_dir=os.path.join(outpath, "metrics"),
+        variational=variational,
     )
 
     onecycle = deepof.model_utils.one_cycle_scheduler(
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index b35a19788010e25c63c73c5d6269710490aec889..f4899f7166461440a621e610b9636d5f243ccd3f 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -117,6 +117,8 @@ def test_autoencoder_fitting(
         phenotype_class=pheno_class,
         predictor=predictor,
         variational=variational,
+        knn_neighbors=10,
+        knn_samples=10,
     )
 
 
@@ -168,6 +170,8 @@ def test_tune_search(
             True,
             True,
             None,
+            knn_neighbors=10,
+            knn_samples=10,
         )
     )[1:]