From f78f765f14a2598d659a8e48a934cd913865ff2b Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Fri, 3 Jul 2020 19:03:47 +0200 Subject: [PATCH] Implemented weight saving callback in model_training.py --- model_training.py | 12 ++++++++---- source/model_utils.py | 24 ++++++++++++++++++++++++ source/models.py | 16 +++++++++++++++- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/model_training.py b/model_training.py index 8e76b24b..b8dbb19b 100644 --- a/model_training.py +++ b/model_training.py @@ -371,10 +371,10 @@ if not variational: validation_data=(input_dict_val[input_type], input_dict_val[input_type]), callbacks=[ tensorboard_callback, + cp_callback, tf.keras.callbacks.EarlyStopping( "val_mae", patience=5, restore_best_weights=True ), - cp_callback, ], ) @@ -384,6 +384,8 @@ else: generator, grouper, gmvaep, + dead_neuron_rate_callback, + silhouette_callback, kl_warmup_callback, mmd_warmup_callback, ) = SEQ_2_SEQ_GMVAE( @@ -401,15 +403,17 @@ else: callbacks_ = [ tensorboard_callback, + cp_callback, + dead_neuron_rate_callback, + silhouette_callback, tf.keras.callbacks.EarlyStopping( "val_mae", patience=5, restore_best_weights=True ), - cp_callback, ] - if "ELBO" in loss: + if "ELBO" in loss and kl_wu > 0: callbacks_.append(kl_warmup_callback) - if "MMD" in loss: + if "MMD" in loss and mmd_wu > 0: callbacks_.append(mmd_warmup_callback) if not predictor: diff --git a/source/model_utils.py b/source/model_utils.py index 32f3112f..17447fa2 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -1,6 +1,7 @@ # @author lucasmiranda42 from keras import backend as K +from sklearn.metrics import silhouette_score from tensorflow.keras.constraints import Constraint from tensorflow.keras.layers import Layer import tensorflow as tf @@ -150,3 +151,26 @@ class MMDiscrepancyLayer(Layer): self.add_metric(self.beta, aggregation="mean", name="mmd_rate") return z + + +class Latent_space_control(Layer): + """ Identity layer that 
adds latent space and clustering stats + to the metrics compiled by the model + """ + + def __init__(self, *args, **kwargs): + super(Latent_space_control, self).__init__(*args, **kwargs) + + def call(self, z, z_gauss, z_cat, **kwargs): + + # Adds metric that monitors dead neurons in the latent space + self.add_metric( + tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons" + ) + + # Adds Silhouette score controlling overlap between clusters + hard_labels = tf.math.argmax(z_cat, axis=1) + silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32) + self.add_metric(silhouette, aggregation="mean", name="silhouette") + + return z diff --git a/source/models.py b/source/models.py index d1e0a197..a54bd76d 100644 --- a/source/models.py +++ b/source/models.py @@ -344,6 +344,18 @@ class SEQ_2_SEQ_GMVAE: z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z) + # z = Latent_space_control()(z, z_gauss, z_cat) + + # Latent space callback to control dead (zero) dimensions in the latent space + dead_neuron_rate_callback = LambdaCallback( + on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss) + ) + + # Latent space callback to control the latent silhouette clustering index + silhouette_callback = LambdaCallback( + on_epoch_end=lambda epoch, logs: tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32) + ) + # Define and instantiate generator generator = Model_D1(z) generator = Model_B1(generator) @@ -429,6 +441,8 @@ class SEQ_2_SEQ_GMVAE: generator, grouper, gmvaep, + dead_neuron_rate_callback, + silhouette_callback, kl_warmup_callback, mmd_warmup_callback, ) @@ -437,7 +451,7 @@ class SEQ_2_SEQ_GMVAE: # TODO: # - latent space metrics to control overregulatization (turned off dimensions). 
Useful for warmup tuning # - Clustering metrics for model selection and aid training (eg early stopping) -# - Silhouette / likelihood / classifier accuracy metrics +# - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics # - design clustering-conscious hyperparameter tuing pipeline # TODO (in the non-immediate future): -- GitLab