Commit 78d8baef authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent a13aac38
......@@ -186,10 +186,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.iteration += 1
K.set_value(self.model.optimizer.lr, rate)
def on_batch_end(self, epoch, logs=None):
"""Add current learning rate as a metric, to check whether scheduling is working properly"""
return self.last_rate
logs["learning_rate"] = self.last_rate
class knn_cluster_purity(tf.keras.callbacks.Callback):
......@@ -208,12 +205,16 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
def on_epoch_end(self, batch: int, logs):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
if self.validation_data is not None:
# Get encoer and grouper from full model
cluster_means = [
layer for layer in self.model.layers if layer.name == "cluster_means"
][0]
cluster_assignment = [
layer for layer in self.model.layers if layer.name == "cluster_assignment"
layer
for layer in self.model.layers
if layer.name == "cluster_assignment"
][0]
encoder = tf.keras.models.Model(
......@@ -223,6 +224,8 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
self.model.layers[0].input, cluster_assignment.output
)
print(self.validation_data)
# Use encoder and grouper to predict on validation data
encoding = encoder.predict(self.validation_data)
groups = grouper.predict(self.validation_data)
......@@ -252,7 +255,7 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
* np.max(groups[sample])
)
return purity_vector.mean()
logs["knn_cluster_purity"] = purity_vector.mean()
class uncorrelated_features_constraint(Constraint):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment