Commit b60427e9 authored by lucas_miranda

Refactored train_model.py

parent f0fd390c
@@ -74,7 +74,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
mmd_warmup_epochs=0,
number_of_components=1,
overlap_loss=False,
-predictor=True,
+predictor=0.0,
prior="standard_normal",
):
super().__init__()
@@ -87,9 +87,11 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
self.LSTM_units_1 = LSTM_units_1
self.LSTM_units_2 = LSTM_units_2
self.kl_warmup = kl_warmup_epochs
+self.kl_warmup_callback = None
self.learn_rate = learn_rate
self.loss = loss
self.mmd_warmup = mmd_warmup_epochs
+self.mmd_warmup_callback = None
self.number_of_components = number_of_components
self.overlap_loss = overlap_loss
self.predictor = predictor
@@ -117,7 +119,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
)
encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24)
-gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
+gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
architecture_hparams={
"units_conv": conv_filters,
"units_lstm": lstm_units_1,
@@ -133,7 +135,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
number_of_components=self.number_of_components,
overlap_loss=self.overlap_loss,
predictor=self.predictor,
-).build(self.input_shape)[3]
+).build(self.input_shape)[3:]
+self.kl_warmup_callback = kl_warmup_callback
+self.mmd_warmup_callback = mmd_warmup_callback
return gmvaep
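Note: build() now returns the KL and MMD warm-up callbacks alongside the model, and the hypermodel stores them so the tuner can attach them during search. A minimal sketch of how such a KL warm-up callback can be implemented with a Keras LambdaCallback; kl_beta and the linear schedule are assumptions for illustration, not taken from this diff:

import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import LambdaCallback

kl_beta = K.variable(1.0, name="kl_beta")  # KL weight, assumed to be read inside the model's loss
kl_warmup = 10  # warm-up length in epochs (hypothetical value)

# Anneal the KL weight linearly from 0 to 1 over the first kl_warmup epochs
kl_warmup_callback = LambdaCallback(
    on_epoch_begin=lambda epoch, logs: K.set_value(kl_beta, min(1.0, epoch / kl_warmup))
)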
@@ -141,3 +146,4 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
# TODO:
# - We can add as many parameters as we want to the hypermodel!
# with this implementation, predictor, warmup, loss and even number of components can be tuned using BayOpt
+# - Number of dense layers close to the latent space as a hyperparameter (!)
@@ -222,7 +222,7 @@ class SEQ_2_SEQ_AE:
model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")
model.compile(
-loss=Huber(reduction="sum", delta=self.delta),
+loss=Huber(delta=self.delta),
optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
metrics=["mae"],
)
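Note: dropping reduction="sum" falls back to Keras's default reduction, which averages the per-sample losses over the batch, so the loss scale no longer grows with batch size. A quick check with toy values (not from this commit):

import tensorflow as tf
from tensorflow.keras.losses import Huber

y_true = tf.constant([[0.0], [0.0]])
y_pred = tf.constant([[1.0], [1.0]])

print(Huber(delta=1.0)(y_true, y_pred).numpy())                   # 0.5, mean over the batch
print(Huber(reduction="sum", delta=1.0)(y_true, y_pred).numpy())  # 1.0, scales with batch size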
@@ -253,7 +253,7 @@ class SEQ_2_SEQ_GMVAE:
kl_warmup_epochs: int = 0,
mmd_warmup_epochs: int = 0,
number_of_components: int = 1,
-predictor: float = True,
+predictor: float = 0.0,
overlap_loss: bool = False,
entropy_reg_weight: float = 0.0,
initialiser_iters: int = int(1e5),
@@ -542,7 +542,7 @@ class SEQ_2_SEQ_GMVAE:
deepof.model_utils.tfd.Independent(
deepof.model_utils.tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
-scale=softplus(gauss[1][..., self.ENCODING:, k]),
+scale=softplus(gauss[1][..., self.ENCODING :, k]),
),
reinterpreted_batch_ndims=1,
)
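Note: the change above only adds a space around the slice colon (black-style formatting for non-trivial slice bounds); behaviour is identical. For context, a self-contained sketch of the per-component Gaussian being built here, with shapes and names assumed for illustration:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch, encoding, components = 32, 6, 3  # assumed sizes
gauss = tf.random.normal([batch, 2 * encoding, components])  # stacked loc/scale parameters

k = 0  # mixture component index
posterior_k = tfd.Independent(
    tfd.Normal(
        loc=gauss[..., :encoding, k],
        scale=tf.math.softplus(gauss[..., encoding:, k]),  # softplus keeps scales positive
    ),
    reinterpreted_batch_ndims=1,  # treat the whole encoding axis as one event
)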
@@ -642,10 +642,10 @@ class SEQ_2_SEQ_GMVAE:
generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
gmvaep.compile(
-loss=Huber(reduction="sum", delta=self.delta),
+loss=Huber(delta=self.delta),
optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
metrics=["mae"],
-loss_weights=([1, self.predictor] if self.predictor > 0 else [1]),
+loss_weights=([1.0, self.predictor] if self.predictor > 0 else [1.0]),
)
gmvaep.build(input_shape)
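Note: with predictor now a float rather than a bool, it doubles as the loss weight of the prediction branch. A self-contained sketch of the compile pattern on a toy two-output model; layer names and sizes are made up:

import tensorflow as tf
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Nadam

inputs = tf.keras.Input(shape=(8,))
hidden = tf.keras.layers.Dense(16, activation="relu")(inputs)
reconstruction = tf.keras.layers.Dense(8, name="reconstruction")(hidden)
prediction = tf.keras.layers.Dense(8, name="prediction")(hidden)

predictor = 0.5  # float weight; 0.0 disables the branch in the real model
model = tf.keras.Model(inputs, [reconstruction, prediction])
model.compile(
    loss=Huber(delta=1.0),
    optimizer=Nadam(learning_rate=1e-3, clipvalue=0.5),
    metrics=["mae"],
    loss_weights=[1.0, predictor],  # reconstruction vs. prediction trade-off
)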
@@ -119,7 +119,7 @@ parser.add_argument(
"--predictor",
"-pred",
help="Activates the prediction branch of the variational Seq 2 Seq model. Defaults to True",
-default=0,
+default=0.0,
type=float,
)
parser.add_argument(
@@ -264,8 +264,8 @@ input_dict_train = {
print("Preprocessing data...")
preprocessed = batch_preprocess(input_dict_train[input_type])
# Get training and validation sets
-X_train = preprocessed[0]
-X_val = preprocessed[1]
+X_train = tf.cast(preprocessed[0], tf.float32)
+X_val = tf.cast(preprocessed[1], tf.float32)
print("Done!")
# Proceed with training mode. Fit autoencoder with the same parameters,
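Note: the explicit casts guard against float64 arrays coming out of preprocessing, which would otherwise clash with the model's float32 weights. Toy example with assumed shapes:

import numpy as np
import tensorflow as tf

preprocessed = [np.random.rand(100, 24, 10), np.random.rand(20, 24, 10)]  # float64 by default
X_train = tf.cast(preprocessed[0], tf.float32)
X_val = tf.cast(preprocessed[1], tf.float32)
print(X_train.dtype)  # float32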
@@ -342,7 +342,7 @@ if not tune:
if "MMD" in loss and mmd_wu > 0:
callbacks_.append(mmd_warmup_callback)
-if predictor == 0:
+if predictor == 0.0:
history = gmvaep.fit(
x=X_train,
y=X_train,
@@ -409,4 +409,4 @@ else:
# TODO:
# - Investigate how Gaussian filters affect reproducibility (in a systematic way)
# - Investigate how smoothing affects reproducibility (in a systematic way)
-# - Check if MCDropout effectively enhances reproducibility or not
\ No newline at end of file
+# - Check if MCDropout effectively enhances reproducibility or not
@@ -146,8 +146,6 @@ def tune_search(
"""
-tensorboard_callback, cp_callback, onecycle = callbacks
if hypermodel == "S2SAE": # pragma: no cover
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
@@ -161,6 +159,12 @@ def tune_search(
predictor=predictor,
overlap_loss=overlap_loss,
)
# if "ELBO" in loss and kl_wu > 0:
# callbacks.append(hypermodel.kl_warmup_callback)
# if "MMD" in loss and mmd_wu > 0:
# callbacks.append(hypermodel.mmd_warmup_callback)
else:
return False
@@ -177,26 +181,22 @@ def tune_search(
print(tuner.search_space_summary())
tuner.search(
train,
-train if predictor == 0 else [train[:-1]],
+train if predictor == 0 else [train[:-1], train[1:]],
epochs=n_epochs,
-validation_data=(test, test if predictor == 0 else [test[:-1], test[1:]]),
+validation_data=(
+(test, test) if predictor == 0 else (test[:-1], [test[:-1], test[1:]])
+),
verbose=1,
batch_size=256,
-callbacks=[
-tensorboard_callback,
-tf.keras.callbacks.EarlyStopping(
-"val_mae", patience=10, restore_best_weights=True
-),
-cp_callback,
-onecycle,
-],
+callbacks=callbacks,
)
-print(tuner.results_summary())
best_hparams = tuner.get_best_hyperparameters(num_trials=1)[0]
best_run = tuner.hypermodel.build(best_hparams)
+print(tuner.results_summary())
return best_hparams, best_run
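Note: when the predictor branch is active, the second target is the input shifted by one window, so the model learns next-window prediction on top of reconstruction. How the offset slices line up, with an assumed array shape:

import numpy as np

train = np.random.rand(100, 24, 10)  # (windows, timesteps, features); assumed shape
x = train[:-1]                       # inputs: all but the last window
y = [train[:-1], train[1:]]          # targets: reconstruct x and predict the following window
assert x.shape[0] == y[1].shape[0]   # inputs and shifted targets stay aligned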