Skip to content
Snippets Groups Projects
Commit f78f765f authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Implemented weight saving callback in model_training.py

parent b834f5e5
No related branches found
No related tags found
No related merge requests found
......@@ -371,10 +371,10 @@ if not variational:
validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[
tensorboard_callback,
cp_callback,
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
cp_callback,
],
)
......@@ -384,6 +384,8 @@ else:
generator,
grouper,
gmvaep,
dead_neuron_rate_callback,
silhouette_callback,
kl_warmup_callback,
mmd_warmup_callback,
) = SEQ_2_SEQ_GMVAE(
......@@ -401,15 +403,17 @@ else:
callbacks_ = [
tensorboard_callback,
cp_callback,
dead_neuron_rate_callback,
silhouette_callback,
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
cp_callback,
]
if "ELBO" in loss:
if "ELBO" in loss and kl_wu > 0:
callbacks_.append(kl_warmup_callback)
if "MMD" in loss:
if "MMD" in loss and mmd_wu > 0:
callbacks_.append(mmd_warmup_callback)
if not predictor:
......
# @author lucasmiranda42
from keras import backend as K
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import tensorflow as tf
......@@ -150,3 +151,26 @@ class MMDiscrepancyLayer(Layer):
self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
return z
class Latent_space_control(Layer):
""" Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
def __init__(self, *args, **kwargs):
super(Latent_space_control, self).__init__(*args, **kwargs)
def call(self, z, z_gauss, z_cat, **kwargs):
# Adds metric that monitors dead neurons in the latent space
self.add_metric(
tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
)
# Adds Silhouette score controling overlap between clusters
hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
self.add_metric(silhouette, aggregation="mean", name="silhouette")
return z
......@@ -344,6 +344,18 @@ class SEQ_2_SEQ_GMVAE:
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
# z = Latent_space_control()(z, z_gauss, z_cat)
# Latent space callback to control dead (zero) dimensions in the latent space
dead_neuron_rate_callback = LambdaCallback(
on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
)
# Latent space callback to control the latent silhouette clustering index
silhouette_callback = LambdaCallback(
on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
)
# Define and instantiate generator
generator = Model_D1(z)
generator = Model_B1(generator)
......@@ -429,6 +441,8 @@ class SEQ_2_SEQ_GMVAE:
generator,
grouper,
gmvaep,
dead_neuron_rate_callback,
silhouette_callback,
kl_warmup_callback,
mmd_warmup_callback,
)
......@@ -437,7 +451,7 @@ class SEQ_2_SEQ_GMVAE:
# TODO:
# - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
# - Clustering metrics for model selection and aid training (eg early stopping)
# - Silhouette / likelihood / classifier accuracy metrics
# - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics
# - design clustering-conscious hyperparameter tuing pipeline
# TODO (in the non-immediate future):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment