Commit 8ddff252 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added nose2body to rule_based_annotation()

parent 4440ac1b
......@@ -186,6 +186,38 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.iteration += 1
K.set_value(self.model.optimizer.lr, rate)
def on_epoch_end(self, epoch, logs=None):
"""Add current learning rate as a metric, to check whether scheduling is working properly"""
pass
class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
Cluster purity callback. Computes assignment purity over K nearest neighbors in the latent space
"""
def __init__(self, trial_data, k=5):
super().__init__()
self.trial_data = trial_data
self.k = k
# noinspection PyMethodOverriding,PyTypeChecker
def on_epoch_end(self, batch: int, logs):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
# Get encoer and grouper from full model
encoder =
grouper =
# Use encoder and grouper to predict on trial data
encoding = encoder.predict(self.trial_data)
groups = grouper.predict(self.trial_data)
#
class uncorrelated_features_constraint(Constraint):
"""
......
......@@ -575,6 +575,7 @@ class SEQ_2_SEQ_GMVAE:
self.ENCODING * self.number_of_components
)
// 2,
name="cluster_means",
activation=None,
kernel_initializer=Orthogonal(), # An alternative is a constant initializer with a matrix of values
# computed from the labels, we could also initialize the prior this way, and update it every N epochs
......@@ -585,6 +586,7 @@ class SEQ_2_SEQ_GMVAE:
self.ENCODING * self.number_of_components
)
// 2,
name="cluster_variances",
activation=None,
activity_regularizer=(
tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment