Commit 3bc7e991 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent 4fa12580
......@@ -384,8 +384,6 @@ else:
generator,
grouper,
gmvaep,
#dead_neuron_rate_callback,
#silhouette_callback,
kl_warmup_callback,
mmd_warmup_callback,
) = SEQ_2_SEQ_GMVAE(
......@@ -404,8 +402,6 @@ else:
callbacks_ = [
tensorboard_callback,
cp_callback,
#dead_neuron_rate_callback,
#silhouette_callback,
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
......
......@@ -344,18 +344,9 @@ class SEQ_2_SEQ_GMVAE:
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
# Identity layer controlling clustering and latent space statistics
z = Latent_space_control()(z, z_gauss, z_cat)
# # Latent space callback to control dead (zero) dimensions in the latent space
# dead_neuron_rate_callback = LambdaCallback(
# on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
# )
#
# # Latent space callback to control the latent silhouette clustering index
# silhouette_callback = LambdaCallback(
# on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
# )
# Define and instantiate generator
generator = Model_D1(z)
generator = Model_B1(generator)
......@@ -441,8 +432,6 @@ class SEQ_2_SEQ_GMVAE:
generator,
grouper,
gmvaep,
#dead_neuron_rate_callback,
#silhouette_callback,
kl_warmup_callback,
mmd_warmup_callback,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment