diff --git a/deepof/model_utils.py b/deepof/model_utils.py index 4b8acb370b188e2e02ac43a6a338cdbd79cfaba6..c35b15619583e0db5bbb4cd715c2bb11b15e01ad 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -186,6 +186,38 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): self.iteration += 1 K.set_value(self.model.optimizer.lr, rate) + def on_epoch_end(self, epoch, logs=None): + """Add current learning rate as a metric, to check whether scheduling is working properly""" + pass + + +class knn_cluster_purity(tf.keras.callbacks.Callback): + """ + + Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space + + """ + + def __init__(self, trial_data, k=5): + super().__init__() + self.trial_data = trial_data + self.k = k + + # noinspection PyMethodOverriding,PyTypeChecker + def on_epoch_end(self, batch: int, logs): + """ Passes samples through the encoder and computes cluster purity on the latent embedding """ + + # Get encoer and grouper from full model + encoder = + grouper = + + # Use encoder and grouper to predict on trial data + encoding = encoder.predict(self.trial_data) + groups = grouper.predict(self.trial_data) + + # + + class uncorrelated_features_constraint(Constraint): """ diff --git a/deepof/models.py b/deepof/models.py index a310dea1251c1b2aae2d5ef8af9051b1324dd04b..f5c67ac16a606308836fa61a4ba11ff810d1dfbf 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -575,6 +575,7 @@ class SEQ_2_SEQ_GMVAE: self.ENCODING * self.number_of_components ) // 2, + name="cluster_means", activation=None, kernel_initializer=Orthogonal(), # An alternative is a constant initializer with a matrix of values # computed from the labels, we could also initialize the prior this way, and update it every N epochs @@ -585,6 +586,7 @@ class SEQ_2_SEQ_GMVAE: self.ENCODING * self.number_of_components ) // 2, + name="cluster_variances", activation=None, activity_regularizer=( tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None