Commit 4fa12580 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent f78f765f
......@@ -384,8 +384,8 @@ else:
generator,
grouper,
gmvaep,
dead_neuron_rate_callback,
silhouette_callback,
#dead_neuron_rate_callback,
#silhouette_callback,
kl_warmup_callback,
mmd_warmup_callback,
) = SEQ_2_SEQ_GMVAE(
......@@ -404,8 +404,8 @@ else:
callbacks_ = [
tensorboard_callback,
cp_callback,
dead_neuron_rate_callback,
silhouette_callback,
#dead_neuron_rate_callback,
#silhouette_callback,
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
......
......@@ -344,17 +344,17 @@ class SEQ_2_SEQ_GMVAE:
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
# z = Latent_space_control()(z, z_gauss, z_cat)
# Latent space callback to control dead (zero) dimensions in the latent space
dead_neuron_rate_callback = LambdaCallback(
on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
)
# Latent space callback to control the latent silhouette clustering index
silhouette_callback = LambdaCallback(
on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
)
z = Latent_space_control()(z, z_gauss, z_cat)
# # Latent space callback to control dead (zero) dimensions in the latent space
# dead_neuron_rate_callback = LambdaCallback(
# on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
# )
#
# # Latent space callback to control the latent silhouette clustering index
# silhouette_callback = LambdaCallback(
# on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
# )
# Define and instantiate generator
generator = Model_D1(z)
......@@ -441,8 +441,8 @@ class SEQ_2_SEQ_GMVAE:
generator,
grouper,
gmvaep,
dead_neuron_rate_callback,
silhouette_callback,
#dead_neuron_rate_callback,
#silhouette_callback,
kl_warmup_callback,
mmd_warmup_callback,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment