diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 4b8acb370b188e2e02ac43a6a338cdbd79cfaba6..c35b15619583e0db5bbb4cd715c2bb11b15e01ad 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -186,6 +186,38 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
         self.iteration += 1
         K.set_value(self.model.optimizer.lr, rate)
 
+    def on_epoch_end(self, epoch, logs=None):
+        """Add current learning rate as a metric, to check whether scheduling is working properly"""
+        pass
+
+
+class knn_cluster_purity(tf.keras.callbacks.Callback):
+    """
+
+    Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space
+
+    """
+
+    def __init__(self, trial_data, k=5):
+        super().__init__()
+        self.trial_data = trial_data
+        self.k = k
+
+    # noinspection PyMethodOverriding,PyTypeChecker
+    def on_epoch_end(self, batch: int, logs):
+        """ Passes samples through the encoder and computes cluster purity on the latent embedding """
+
+        # Get encoer and grouper from full model
+        encoder =
+        grouper =
+
+        # Use encoder and grouper to predict on trial data
+        encoding = encoder.predict(self.trial_data)
+        groups = grouper.predict(self.trial_data)
+
+        #
+
+
 
 class uncorrelated_features_constraint(Constraint):
     """
diff --git a/deepof/models.py b/deepof/models.py
index a310dea1251c1b2aae2d5ef8af9051b1324dd04b..f5c67ac16a606308836fa61a4ba11ff810d1dfbf 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -575,6 +575,7 @@ class SEQ_2_SEQ_GMVAE:
                 self.ENCODING * self.number_of_components
             )
             // 2,
+            name="cluster_means",
             activation=None,
             kernel_initializer=Orthogonal(),  # An alternative is a constant initializer with a matrix of values
             # computed from the labels, we could also initialize the prior this way, and update it every N epochs
@@ -585,6 +586,7 @@ class SEQ_2_SEQ_GMVAE:
                 self.ENCODING * self.number_of_components
             )
             // 2,
+            name="cluster_variances",
             activation=None,
             activity_regularizer=(
                 tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None