From 78d8baef71dd84b1fab01648f232fffeb364a370 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 12 Mar 2021 01:24:49 +0100
Subject: [PATCH] Prototyped KNN_purity callback

---
 deepof/model_utils.py | 85 ++++++++++++++++++++++---------------------
 1 file changed, 44 insertions(+), 41 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 8a2fe23a..9c895bfb 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
         self.iteration += 1
         K.set_value(self.model.optimizer.lr, rate)
 
-    def on_batch_end(self, epoch, logs=None):
-        """Add current learning rate as a metric, to check whether scheduling is working properly"""
-
-        return self.last_rate
+        logs["learning_rate"] = self.last_rate
 
 
 class knn_cluster_purity(tf.keras.callbacks.Callback):
@@ -208,51 +205,57 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
     def on_epoch_end(self, batch: int, logs):
         """ Passes samples through the encoder and computes cluster purity on the latent embedding """
 
-        # Get encoer and grouper from full model
-        cluster_means = [
-            layer for layer in self.model.layers if layer.name == "cluster_means"
-        ][0]
-        cluster_assignment = [
-            layer for layer in self.model.layers if layer.name == "cluster_assignment"
-        ][0]
+        if self.validation_data is not None:
 
-        encoder = tf.keras.models.Model(
-            self.model.layers[0].input, cluster_means.output
-        )
-        grouper = tf.keras.models.Model(
-            self.model.layers[0].input, cluster_assignment.output
-        )
+            # Get encoder and grouper from full model
+            cluster_means = [
+                layer for layer in self.model.layers if layer.name == "cluster_means"
+            ][0]
+            cluster_assignment = [
+                layer
+                for layer in self.model.layers
+                if layer.name == "cluster_assignment"
+            ][0]
 
-        # Use encoder and grouper to predict on validation data
-        encoding = encoder.predict(self.validation_data)
-        groups = grouper.predict(self.validation_data)
+            encoder = tf.keras.models.Model(
+                self.model.layers[0].input, cluster_means.output
+            )
+            grouper = tf.keras.models.Model(
+                self.model.layers[0].input, cluster_assignment.output
+            )
 
-        # Multiply encodings by groups, to get a weighted version of the matrix
-        encoding = (
-            encoding
-            * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
-        )
-        hard_groups = groups.argmax(axis=1)
+            print(self.validation_data)
 
-        # Fit KNN model
-        knn = NearestNeighbors().fit(encoding)
+            # Use encoder and grouper to predict on validation data
+            encoding = encoder.predict(self.validation_data)
+            groups = grouper.predict(self.validation_data)
 
-        # Iterate over samples and compute purity over k neighbours
-        random_idxs = np.random.choice(
-            range(encoding.shape[0]), self.samples, replace=False
-        )
-        purity_vector = np.zeros(self.samples)
-        for i, sample in enumerate(random_idxs):
-            indexes = knn.kneighbors(
-                encoding[sample][np.newaxis, :], self.k, return_distance=False
+            # Multiply encodings by groups, to get a weighted version of the matrix
+            encoding = (
+                encoding
+                * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy()
             )
-            purity_vector[i] = (
-                np.sum(hard_groups[indexes] == hard_groups[sample])
-                / self.k
-                * np.max(groups[sample])
+            hard_groups = groups.argmax(axis=1)
+
+            # Fit KNN model
+            knn = NearestNeighbors().fit(encoding)
+
+            # Iterate over samples and compute purity over k neighbours
+            random_idxs = np.random.choice(
+                range(encoding.shape[0]), self.samples, replace=False
             )
+            purity_vector = np.zeros(self.samples)
+            for i, sample in enumerate(random_idxs):
+                indexes = knn.kneighbors(
+                    encoding[sample][np.newaxis, :], self.k, return_distance=False
+                )
+                purity_vector[i] = (
+                    np.sum(hard_groups[indexes] == hard_groups[sample])
+                    / self.k
+                    * np.max(groups[sample])
+                )
 
-        return purity_vector.mean()
+            logs["knn_cluster_purity"] = purity_vector.mean()
 
 
 class uncorrelated_features_constraint(Constraint):
-- 
GitLab