Commit f0fd390c authored by lucas_miranda

Refactored train_model.py

parent 93ba8a91
Pipeline #83731 passed with stage in 33 minutes and 21 seconds
@@ -542,7 +542,7 @@ class SEQ_2_SEQ_GMVAE:
deepof.model_utils.tfd.Independent(
deepof.model_utils.tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING :, k]),
scale=softplus(gauss[1][..., self.ENCODING:, k]),
),
reinterpreted_batch_ndims=1,
)
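Note: the hunk above only adjusts slice whitespace; the construction itself is a softplus-scaled Gaussian mixture component. Below is a minimal standalone sketch of that construction with TensorFlow Probability, dropping the component index `k`; `params` and `encoding` are placeholders, not deepof identifiers.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
encoding = 6                                   # hypothetical latent size
params = tf.random.normal([32, 2 * encoding])  # stands in for the dense layer output

# One Gaussian mixture component: the first `encoding` values are means,
# the remaining ones are raw scales pushed through softplus to stay positive.
component = tfd.Independent(
    tfd.Normal(
        loc=params[..., :encoding],
        scale=tf.nn.softplus(params[..., encoding:]),
    ),
    reinterpreted_batch_ndims=1,  # treat the latent axis as event dimensions
)
```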
@@ -641,14 +641,8 @@ class SEQ_2_SEQ_GMVAE:
_x_decoded_mean = TimeDistributed(Dense(input_shape[2]))(_generator)
generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
def huber_loss(x_, x_decoded_mean_): # pragma: no cover
"""Computes huber loss with a fixed delta"""
huber = Huber(reduction="sum", delta=self.delta)
return input_shape[1:] * huber(x_, x_decoded_mean_)
gmvaep.compile(
loss=huber_loss,
loss=Huber(reduction="sum", delta=self.delta),
optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
metrics=["mae"],
loss_weights=([1, self.predictor] if self.predictor > 0 else [1]),
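For context, a sketch of what this hunk changes: the removed closure wrapped Keras' built-in Huber loss and rescaled the summed result by `input_shape[1:]`, whereas the new code hands a plain `Huber` instance to `compile()` and drops that rescaling. `delta` and `seq_length` below are placeholder values, not deepof's.

```python
from tensorflow.keras.losses import Huber

delta = 1.0      # placeholder for self.delta
seq_length = 15  # placeholder for the rescaling factor taken from input_shape

# Before: custom wrapper, summed Huber loss rescaled by the sequence dimension.
def huber_loss(x_, x_decoded_mean_):
    huber = Huber(reduction="sum", delta=delta)
    return seq_length * huber(x_, x_decoded_mean_)

# After: the built-in loss object is used directly, without the rescaling.
loss = Huber(reduction="sum", delta=delta)
```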
@@ -13,6 +13,7 @@ from deepof.data import *
from deepof.models import *
from deepof.utils import *
from train_utils import *
from tensorboard.plugins.hparams import api as hp
from tensorflow import keras
parser = argparse.ArgumentParser(
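The new `tensorboard.plugins.hparams` import is presumably used to log hyperparameter-tuning runs. As a reminder of what that API looks like, here is a minimal sketch independent of deepof's actual logging code; the hyperparameter and metric names are illustrative.

```python
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

# Declare a hyperparameter once, then log one set of values per run.
HP_LR = hp.HParam("learning_rate", hp.RealInterval(1e-4, 1e-2))

with tf.summary.create_file_writer("logs/run_0").as_default():
    hp.hparams({HP_LR: 1e-3})                   # record this run's hyperparameter values
    tf.summary.scalar("val_mae", 0.12, step=1)  # record the metric obtained with them
```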
@@ -61,14 +62,6 @@ parser.add_argument(
type=str2bool,
default=False,
)
parser.add_argument(
"--hypermodel",
"-m",
help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, S2SVAEP, "
"S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
type=str,
default="S2SVAE",
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
@@ -183,7 +176,6 @@ bayopt_trials = args.bayopt
exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
hyp = args.hypermodel
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
@@ -270,17 +262,12 @@ input_dict_train = {
}
print("Preprocessing data...")
for key, value in input_dict_train.items():
input_dict_train[key] = batch_preprocess(value)
print("Done!")
print("Creating training and validation sets...")
preprocessed = batch_preprocess(input_dict_train[input_type])
# Get training and validation sets
X_train = input_dict_train[input_type][0]
X_val = input_dict_train[input_type][1]
X_train = preprocessed[0]
X_val = preprocessed[1]
print("Done!")
# Proceed with training mode. Fit autoencoder with the same parameters,
# as many times as specified by runs
if not tune:
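A sketch of the control-flow change above: instead of preprocessing every entry of `input_dict_train`, only the representation selected by `input_type` is passed through `batch_preprocess`, which, judging from the indexing in the diff, returns the training and validation arrays as its first two elements. The toy preprocessing function and data below are placeholders, not deepof's implementation.

```python
import numpy as np

# Hypothetical stand-in for deepof's batch_preprocess: shuffles an array
# and splits it into a (train, validation) pair.
def batch_preprocess(data, val_fraction=0.2):
    data = np.random.default_rng(0).permutation(data)
    split = int(len(data) * (1 - val_fraction))
    return data[:split], data[split:]

input_dict_train = {"coords": np.random.rand(100, 15, 12)}  # toy data
input_type = "coords"

# Before: every representation in the dict was preprocessed.
# After: only the one selected by --input-type is.
preprocessed = batch_preprocess(input_dict_train[input_type])
X_train, X_val = preprocessed[0], preprocessed[1]
```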
@@ -384,6 +371,8 @@ if not tune:
else:
# Runs hyperparameter tuning with the specified parameters and saves the results
hyp = "S2SGMVAE" if variational else "S2SAE"
run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
)
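With the --hypermodel option gone (see the argparse hunk above), the hypermodel now follows directly from the variational flag; a minimal illustration with a placeholder flag value:

```python
variational = True  # placeholder for the value parsed from the command line

# Only two hypermodels remain; the flag picks between them.
hyp = "S2SGMVAE" if variational else "S2SAE"
```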
@@ -146,10 +146,9 @@ def tune_search(
"""
print(callbacks)
tensorboard_callback, cp_callback, onecycle = callbacks
if hypermodel == "S2SAE":
if hypermodel == "S2SAE": # pragma: no cover
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
elif hypermodel == "S2SGMVAE":
@@ -179,9 +178,9 @@ def tune_search(
tuner.search(
train,
train,
train if predictor == 0 else [train[:-1], train[1:]],
epochs=n_epochs,
validation_data=(test, test),
validation_data=(test, test if predictor == 0 else [test[:-1], test[1:]]),
verbose=1,
batch_size=256,
callbacks=[
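The interesting change in this hunk is the target passed to `tuner.search`: when the model includes a predictor branch (`predictor > 0`), the single reconstruction target becomes the pair `[train[:-1], train[1:]]`, i.e. the current windows plus the windows shifted one step ahead, and the validation targets are built the same way. A small sketch of that pairing with a toy array; shapes are placeholders.

```python
import numpy as np

train = np.random.rand(100, 15, 12)  # toy (windows, timesteps, features)
predictor = 1.0                      # placeholder predictor-branch weight

# Reconstruction-only model: the target equals the input windows.
# With a predictor branch: two targets, the current windows and the
# windows shifted one step ahead.
y = train if predictor == 0 else [train[:-1], train[1:]]
```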
@@ -80,12 +80,12 @@ def str2bool(v: str) -> bool:
"""
if isinstance(v, bool):
return v
return v # pragma: no cover
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
else: # pragma: no cover
raise argparse.ArgumentTypeError("Boolean compatible value expected.")
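For reference, this helper is meant to be used as an argparse `type`, as in the flags earlier in the diff. A minimal usage sketch; the flag name and import location are illustrative, not confirmed by the diff.

```python
import argparse

from train_utils import str2bool  # hypothetical location; the function is shown in the hunk above

parser = argparse.ArgumentParser()
parser.add_argument("--variational", type=str2bool, default=False)  # illustrative flag name

args = parser.parse_args(["--variational", "yes"])
assert args.variational is True  # "true", "t", "y" and "1" parse the same way
```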
@@ -80,7 +80,7 @@ def test_get_callbacks(
elements=st.floats(min_value=0.0, max_value=1,),
),
batch_size=st.integers(min_value=128, max_value=512),
hypermodel=st.one_of(st.just("S2SAE"), st.just("S2SGMVAE")),
hypermodel=st.just("S2SGMVAE"),
k=st.integers(min_value=1, max_value=10),
kl_wu=st.integers(min_value=0, max_value=10),
loss=st.one_of(st.just("ELBO"), st.just("MMD")),
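Narrowing the strategy to `st.just("S2SGMVAE")` means the test no longer draws the removed autoencoder hypermodel. As a reminder of how such strategies drive a test, here is a generic sketch, not the deepof test itself:

```python
from hypothesis import given, strategies as st

@given(
    hypermodel=st.just("S2SGMVAE"),
    loss=st.one_of(st.just("ELBO"), st.just("MMD")),
)
def test_strategy_values(hypermodel, loss):
    # Every generated example now uses the single remaining hypermodel.
    assert hypermodel == "S2SGMVAE"
    assert loss in ("ELBO", "MMD")
```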