Commit 78d8baef authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent a13aac38
...@@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): ...@@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.iteration += 1 self.iteration += 1
K.set_value(self.model.optimizer.lr, rate) K.set_value(self.model.optimizer.lr, rate)
def on_batch_end(self, epoch, logs=None): logs["learning_rate"] = self.last_rate
"""Add current learning rate as a metric, to check whether scheduling is working properly"""
return self.last_rate
class knn_cluster_purity(tf.keras.callbacks.Callback): class knn_cluster_purity(tf.keras.callbacks.Callback):
...@@ -208,12 +205,16 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -208,12 +205,16 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
def on_epoch_end(self, batch: int, logs): def on_epoch_end(self, batch: int, logs):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """ """ Passes samples through the encoder and computes cluster purity on the latent embedding """
if self.validation_data is not None:
# Get encoer and grouper from full model # Get encoer and grouper from full model
cluster_means = [ cluster_means = [
layer for layer in self.model.layers if layer.name == "cluster_means" layer for layer in self.model.layers if layer.name == "cluster_means"
][0] ][0]
cluster_assignment = [ cluster_assignment = [
layer for layer in self.model.layers if layer.name == "cluster_assignment" layer
for layer in self.model.layers
if layer.name == "cluster_assignment"
][0] ][0]
encoder = tf.keras.models.Model( encoder = tf.keras.models.Model(
...@@ -223,6 +224,8 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -223,6 +224,8 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
self.model.layers[0].input, cluster_assignment.output self.model.layers[0].input, cluster_assignment.output
) )
print(self.validation_data)
# Use encoder and grouper to predict on validation data # Use encoder and grouper to predict on validation data
encoding = encoder.predict(self.validation_data) encoding = encoder.predict(self.validation_data)
groups = grouper.predict(self.validation_data) groups = grouper.predict(self.validation_data)
...@@ -252,7 +255,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -252,7 +255,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
* np.max(groups[sample]) * np.max(groups[sample])
) )
return purity_vector.mean() logs["knn_cluster_purity"] = purity_vector.mean()
class uncorrelated_features_constraint(Constraint): class uncorrelated_features_constraint(Constraint):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment