Commit 3bc7e991 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent 4fa12580
...@@ -384,8 +384,6 @@ else: ...@@ -384,8 +384,6 @@ else:
generator, generator,
grouper, grouper,
gmvaep, gmvaep,
#dead_neuron_rate_callback,
#silhouette_callback,
kl_warmup_callback, kl_warmup_callback,
mmd_warmup_callback, mmd_warmup_callback,
) = SEQ_2_SEQ_GMVAE( ) = SEQ_2_SEQ_GMVAE(
...@@ -404,8 +402,6 @@ else: ...@@ -404,8 +402,6 @@ else:
callbacks_ = [ callbacks_ = [
tensorboard_callback, tensorboard_callback,
cp_callback, cp_callback,
#dead_neuron_rate_callback,
#silhouette_callback,
tf.keras.callbacks.EarlyStopping( tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True "val_mae", patience=5, restore_best_weights=True
), ),
......
...@@ -344,18 +344,9 @@ class SEQ_2_SEQ_GMVAE: ...@@ -344,18 +344,9 @@ class SEQ_2_SEQ_GMVAE:
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z) z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
# Identity layer controlling clustering and latent space statistics
z = Latent_space_control()(z, z_gauss, z_cat) z = Latent_space_control()(z, z_gauss, z_cat)
# # Latent space callback to control dead (zero) dimensions in the latent space
# dead_neuron_rate_callback = LambdaCallback(
# on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
# )
#
# # Latent space callback to control the latent silhouette clustering index
# silhouette_callback = LambdaCallback(
# on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
# )
# Define and instantiate generator # Define and instantiate generator
generator = Model_D1(z) generator = Model_D1(z)
generator = Model_B1(generator) generator = Model_B1(generator)
...@@ -441,8 +432,6 @@ class SEQ_2_SEQ_GMVAE: ...@@ -441,8 +432,6 @@ class SEQ_2_SEQ_GMVAE:
generator, generator,
grouper, grouper,
gmvaep, gmvaep,
#dead_neuron_rate_callback,
#silhouette_callback,
kl_warmup_callback, kl_warmup_callback,
mmd_warmup_callback, mmd_warmup_callback,
) )
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment