Commit 2cf52af5 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent 5c4515b1
......@@ -311,7 +311,7 @@ if not variational:
validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[
tensorboard_callback,
tf.keras.callbacks.EarlyStopping("val_mae", patience=5),
tf.keras.callbacks.EarlyStopping("val_mae", patience=5, restore_best_weights=True),
cp_callback,
],
)
......@@ -322,8 +322,8 @@ else:
generator,
grouper,
gmvaep,
kl_warmup_callback,
mmd_warmup_callback,
(mmd_warmup_callback if "MMD" in loss else None),
(kl_warmup_callback if "ELBO" in loss else None),
) = SEQ_2_SEQ_GMVAE(
input_dict_train[input_type].shape,
loss=loss,
......@@ -347,9 +347,9 @@ else:
validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[
tensorboard_callback,
kl_warmup_callback,
mmd_warmup_callback,
tf.keras.callbacks.EarlyStopping("val_mae", patience=5),
(mmd_warmup_callback if "MMD" in loss else None),
(kl_warmup_callback if "ELBO" in loss else None),
tf.keras.callbacks.EarlyStopping("val_mae", patience=5, restore_best_weights=True),
cp_callback,
],
)
......@@ -366,8 +366,8 @@ else:
),
callbacks=[
tensorboard_callback,
kl_warmup_callback,
mmd_warmup_callback,
(mmd_warmup_callback if "MMD" in loss else None),
(kl_warmup_callback if "ELBO" in loss else None),
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment