Commit d6ad5d0f authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent 2cf52af5
......@@ -311,7 +311,9 @@ if not variational:
validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[
tensorboard_callback,
tf.keras.callbacks.EarlyStopping("val_mae", patience=5, restore_best_weights=True),
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
cp_callback,
],
)
......@@ -322,8 +324,8 @@ else:
generator,
grouper,
gmvaep,
(mmd_warmup_callback if "MMD" in loss else None),
(kl_warmup_callback if "ELBO" in loss else None),
mmd_warmup_callback,
kl_warmup_callback,
) = SEQ_2_SEQ_GMVAE(
input_dict_train[input_type].shape,
loss=loss,
......@@ -337,6 +339,19 @@ else:
print(gmvaep.summary())
callbacks_ = [
tensorboard_callback,
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
cp_callback,
]
if "ELBO" in loss:
callbacks_.append(kl_warmup_callback)
if "MMD" in loss:
callbacks_.append(mmd_warmup_callback)
if not predictor:
history = gmvaep.fit(
x=input_dict_train[input_type],
......@@ -345,13 +360,7 @@ else:
batch_size=512,
verbose=1,
validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[
tensorboard_callback,
(mmd_warmup_callback if "MMD" in loss else None),
(kl_warmup_callback if "ELBO" in loss else None),
tf.keras.callbacks.EarlyStopping("val_mae", patience=5, restore_best_weights=True),
cp_callback,
],
callbacks=callbacks_,
)
else:
history = gmvaep.fit(
......@@ -364,13 +373,5 @@ else:
input_dict_val[input_type][:-1],
[input_dict_val[input_type][:-1], input_dict_val[input_type][1:]],
),
callbacks=[
tensorboard_callback,
(mmd_warmup_callback if "MMD" in loss else None),
(kl_warmup_callback if "ELBO" in loss else None),
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=5, restore_best_weights=True
),
cp_callback,
],
callbacks=callbacks_,
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment