Commit f78f765f authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent b834f5e5
...@@ -371,10 +371,10 @@ if not variational: ...@@ -371,10 +371,10 @@ if not variational:
validation_data=(input_dict_val[input_type], input_dict_val[input_type]), validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[ callbacks=[
tensorboard_callback, tensorboard_callback,
cp_callback,
tf.keras.callbacks.EarlyStopping( tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True "val_mae", patience=5, restore_best_weights=True
), ),
cp_callback,
], ],
) )
...@@ -384,6 +384,8 @@ else: ...@@ -384,6 +384,8 @@ else:
generator, generator,
grouper, grouper,
gmvaep, gmvaep,
dead_neuron_rate_callback,
silhouette_callback,
kl_warmup_callback, kl_warmup_callback,
mmd_warmup_callback, mmd_warmup_callback,
) = SEQ_2_SEQ_GMVAE( ) = SEQ_2_SEQ_GMVAE(
...@@ -401,15 +403,17 @@ else: ...@@ -401,15 +403,17 @@ else:
callbacks_ = [ callbacks_ = [
tensorboard_callback, tensorboard_callback,
cp_callback,
dead_neuron_rate_callback,
silhouette_callback,
tf.keras.callbacks.EarlyStopping( tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True "val_mae", patience=5, restore_best_weights=True
), ),
cp_callback,
] ]
if "ELBO" in loss: if "ELBO" in loss and kl_wu > 0:
callbacks_.append(kl_warmup_callback) callbacks_.append(kl_warmup_callback)
if "MMD" in loss: if "MMD" in loss and mmd_wu > 0:
callbacks_.append(mmd_warmup_callback) callbacks_.append(mmd_warmup_callback)
if not predictor: if not predictor:
......
# @author lucasmiranda42 # @author lucasmiranda42
from keras import backend as K from keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer from tensorflow.keras.layers import Layer
import tensorflow as tf import tensorflow as tf
...@@ -150,3 +151,26 @@ class MMDiscrepancyLayer(Layer): ...@@ -150,3 +151,26 @@ class MMDiscrepancyLayer(Layer):
self.add_metric(self.beta, aggregation="mean", name="mmd_rate") self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
return z return z
class Latent_space_control(Layer):
""" Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
def __init__(self, *args, **kwargs):
super(Latent_space_control, self).__init__(*args, **kwargs)
def call(self, z, z_gauss, z_cat, **kwargs):
# Adds metric that monitors dead neurons in the latent space
self.add_metric(
tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
)
# Adds Silhouette score controling overlap between clusters
hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
self.add_metric(silhouette, aggregation="mean", name="silhouette")
return z
...@@ -344,6 +344,18 @@ class SEQ_2_SEQ_GMVAE: ...@@ -344,6 +344,18 @@ class SEQ_2_SEQ_GMVAE:
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z) z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
# z = Latent_space_control()(z, z_gauss, z_cat)
# Latent space callback to control dead (zero) dimensions in the latent space
dead_neuron_rate_callback = LambdaCallback(
on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
)
# Latent space callback to control the latent silhouette clustering index
silhouette_callback = LambdaCallback(
on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
)
# Define and instantiate generator # Define and instantiate generator
generator = Model_D1(z) generator = Model_D1(z)
generator = Model_B1(generator) generator = Model_B1(generator)
...@@ -429,6 +441,8 @@ class SEQ_2_SEQ_GMVAE: ...@@ -429,6 +441,8 @@ class SEQ_2_SEQ_GMVAE:
generator, generator,
grouper, grouper,
gmvaep, gmvaep,
dead_neuron_rate_callback,
silhouette_callback,
kl_warmup_callback, kl_warmup_callback,
mmd_warmup_callback, mmd_warmup_callback,
) )
...@@ -437,7 +451,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -437,7 +451,7 @@ class SEQ_2_SEQ_GMVAE:
# TODO: # TODO:
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning # - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
# - Clustering metrics for model selection and aid training (eg early stopping) # - Clustering metrics for model selection and aid training (eg early stopping)
# - Silhouette / likelihood / classifier accuracy metrics # - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics
# - design clustering-conscious hyperparameter tuing pipeline # - design clustering-conscious hyperparameter tuing pipeline
# TODO (in the non-immediate future): # TODO (in the non-immediate future):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment