From 78d8baef71dd84b1fab01648f232fffeb364a370 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Fri, 12 Mar 2021 01:24:49 +0100 Subject: [PATCH] Prototyped KNN_purity callback --- deepof/model_utils.py | 85 ++++++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/deepof/model_utils.py b/deepof/model_utils.py index 8a2fe23a..9c895bfb 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): self.iteration += 1 K.set_value(self.model.optimizer.lr, rate) - def on_batch_end(self, epoch, logs=None): - """Add current learning rate as a metric, to check whether scheduling is working properly""" - - return self.last_rate + logs["learning_rate"] = self.last_rate class knn_cluster_purity(tf.keras.callbacks.Callback): @@ -208,51 +205,57 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): def on_epoch_end(self, batch: int, logs): """ Passes samples through the encoder and computes cluster purity on the latent embedding """ - # Get encoer and grouper from full model - cluster_means = [ - layer for layer in self.model.layers if layer.name == "cluster_means" - ][0] - cluster_assignment = [ - layer for layer in self.model.layers if layer.name == "cluster_assignment" - ][0] + if self.validation_data is not None: - encoder = tf.keras.models.Model( - self.model.layers[0].input, cluster_means.output - ) - grouper = tf.keras.models.Model( - self.model.layers[0].input, cluster_assignment.output - ) + # Get encoer and grouper from full model + cluster_means = [ + layer for layer in self.model.layers if layer.name == "cluster_means" + ][0] + cluster_assignment = [ + layer + for layer in self.model.layers + if layer.name == "cluster_assignment" + ][0] - # Use encoder and grouper to predict on validation data - encoding = encoder.predict(self.validation_data) - groups = grouper.predict(self.validation_data) + encoder = tf.keras.models.Model( + self.model.layers[0].input, cluster_means.output + ) + grouper = tf.keras.models.Model( + self.model.layers[0].input, cluster_assignment.output + ) - # Multiply encodings by groups, to get a weighted version of the matrix - encoding = ( - encoding - * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy() - ) - hard_groups = groups.argmax(axis=1) + print(self.validation_data) - # Fit KNN model - knn = NearestNeighbors().fit(encoding) + # Use encoder and grouper to predict on validation data + encoding = encoder.predict(self.validation_data) + groups = grouper.predict(self.validation_data) - # Iterate over samples and compute purity over k neighbours - random_idxs = np.random.choice( - range(encoding.shape[0]), self.samples, replace=False - ) - purity_vector = np.zeros(self.samples) - for i, sample in enumerate(random_idxs): - indexes = knn.kneighbors( - encoding[sample][np.newaxis, :], self.k, return_distance=False + # Multiply encodings by groups, to get a weighted version of the matrix + encoding = ( + encoding + * tf.tile(groups, [1, encoding.shape[1] // groups.shape[1]]).numpy() ) - purity_vector[i] = ( - np.sum(hard_groups[indexes] == hard_groups[sample]) - / self.k - * np.max(groups[sample]) + hard_groups = groups.argmax(axis=1) + + # Fit KNN model + knn = NearestNeighbors().fit(encoding) + + # Iterate over samples and compute purity over k neighbours + random_idxs = np.random.choice( + range(encoding.shape[0]), self.samples, replace=False ) + purity_vector = np.zeros(self.samples) + for i, sample in enumerate(random_idxs): + indexes = knn.kneighbors( + encoding[sample][np.newaxis, :], self.k, return_distance=False + ) + purity_vector[i] = ( + np.sum(hard_groups[indexes] == hard_groups[sample]) + / self.k + * np.max(groups[sample]) + ) - return purity_vector.mean() + logs["knn_cluster_purity"] = purity_vector.mean() class uncorrelated_features_constraint(Constraint): -- GitLab