Commit a07fc83b authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored hyperparameter tuning in train_utils.py

parent 563b7880
......@@ -116,6 +116,8 @@ def tune_search(
predictor: float,
project_name: str,
callbacks: List,
n_epochs: int = 30,
n_replicas: int = 1,
) -> Union[bool, Tuple[Any, Any]]:
"""Define the search space using keras-tuner and bayesian optimization
......@@ -123,8 +125,8 @@ def tune_search(
- train (np.array): dataset to train the model on
- test (np.array): dataset to validate the model on
- bayopt_trials (int): number of Bayesian optimization iterations to run
- hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder) or
S2SGMVAE (Gaussian Mixture Variational autoencoder).
- hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder)
or S2SGMVAE (Gaussian Mixture Variational autoencoder).
- k (int) number of components of the Gaussian Mixture
- kl_wu (int): number of epochs for KL divergence warm up
- loss (str): one of [ELBO, MMD, ELBO+MMD]
......@@ -135,6 +137,9 @@ def tune_search(
which tries to predict the next frame from the current one
- project_name (str): ID of the current run
- callbacks (list): list of callbacks for the training loop
- n_epochs (int): optional. Number of epochs to train each run for
- n_replicas (int): optional. Number of replicas per parameter set. Higher values
will yield more robust results, but will affect performance severely
Returns:
- best_hparams (dict): dictionary with the best retrieved hyperparameters
......@@ -163,7 +168,7 @@ def tune_search(
tuner = BayesianOptimization(
hypermodel,
max_trials=bayopt_trials,
executions_per_trial=1,
executions_per_trial=n_replicas,
objective="val_mae",
seed=42,
directory="BayesianOptx",
......@@ -175,13 +180,17 @@ def tune_search(
tuner.search(
train,
train,
epochs=30,
epochs=n_epochs,
validation_data=(test, test),
verbose=1,
batch_size=256,
callbacks=[
tensorboard_callback,
tf.keras.callbacks.EarlyStopping("val_mae", patience=5),
tf.keras.callbacks.EarlyStopping(
"val_mae", patience=10, restore_best_weights=True
),
cp_callback,
onecycle,
],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment