Commit 56de58c2 authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 77d77a1d
......@@ -186,9 +186,14 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.iteration += 1
K.set_value(self.model.optimizer.lr, rate)
def on_epoch_end(self, epoch, logs=None):
def on_batch_end(self, epoch, logs=None):
"""Add current learning rate as a metric, to check whether scheduling is working properly"""
pass
self.add_metric(
self.last_rate,
aggregation="mean",
name="learning_rate",
)
class knn_cluster_purity(tf.keras.callbacks.Callback):
......@@ -198,25 +203,40 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def __init__(self, trial_data, k=5):
def __init__(self, k=5, samples=1000):
super().__init__()
self.trial_data = trial_data
self.k = k
self.samples = samples
# noinspection PyMethodOverriding,PyTypeChecker
def on_epoch_end(self, batch: int, logs):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
# Get encoer and grouper from full model
encoder = 0
grouper = 0
cluster_means = [
layer for layer in self.model.layers if layer.name == "cluster_means"
][0]
cluster_assignment = [
layer for layer in self.model.layers if layer.name == "cluster_assignment"
][0]
encoder = tf.keras.models.Model(
self.model.layers[0].input, cluster_means.output
)
grouper = tf.keras.models.Model(
self.model.layers[0].input, cluster_assignment.output
)
# Use encoder and grouper to predict on trial data
encoding = encoder.predict(self.trial_data)
groups = grouper.predict(self.trial_data)
trial_idxs = np.random.choice(
range(self.validation_data.shape[0]), self.samples
)
trial_data = self.validation_data[trial_idxs]
#
# Use encoder and grouper to predict on validation data
encoding = encoder.predict(trial_data)
groups = grouper.predict(trial_data)
#
class uncorrelated_features_constraint(Constraint):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment