Commit 6cfdf470 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added hyperband support for hyperparameter tuning

parent 1fea8558
Pipeline #88030 failed with stage
in 13 minutes and 32 seconds
......@@ -40,13 +40,6 @@ parser.add_argument(
type=int,
default=512,
)
parser.add_argument(
"--bayopt",
"-n",
help="sets the number of Bayesian optimization iterations to run. Default is 25",
type=int,
default=25,
)
parser.add_argument(
"--components",
"-k",
......@@ -68,10 +61,18 @@ parser.add_argument(
type=str2bool,
default=False,
)
parser.add_argument(
"--hpt_trials",
"-n",
help="sets the number of hyperparameter tuning iterations to run. Default is 25",
type=int,
default=25,
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
help="If True, hyperparameter tuning is performed. See documentation for details",
help="Indicates whether hyperparameters should be tuned either using 'bayopt' or 'hyperband'. "
"See documentation for details",
type=str2bool,
default=False,
)
......@@ -187,7 +188,7 @@ args = parser.parse_args()
animal_id = args.animal_id
arena_dims = args.arena_dims
batch_size = args.batch_size
bayopt_trials = args.bayopt
hypertun_trials = args.hpt_trials
exclude_bodyparts = list(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
......@@ -412,7 +413,8 @@ else:
best_hyperparameters, best_model = tune_search(
data=[X_train, y_train, X_val, y_val],
bayopt_trials=bayopt_trials,
hypertun_trials=hypertun_trials,
hpt_type=tune,
hypermodel=hyp,
k=k,
kl_warmup_epochs=kl_wu,
......@@ -421,7 +423,7 @@ else:
overlap_loss=overlap_loss,
pheno_class=pheno_class,
predictor=predictor,
project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
callbacks=[
tensorboard_callback,
onecycle,
......@@ -436,12 +438,12 @@ else:
# Saves a compiled, untrained version of the best model
best_model.build(X_train.shape)
best_model.save(
"{}-based_{}_BAYESIAN_OPT.h5".format(input_type, hyp), save_format="tf"
"{}-based_{}_{}.h5".format(input_type, hyp, tune.capitalize()), save_format="tf"
)
# Saves the best hyperparameters
with open(
"{}-based_{}_BAYESIAN_OPT_params.pickle".format(input_type, hyp), "wb"
"{}-based_{}_{}_params.pickle".format(input_type, hyp, tune.capitalize()), "wb"
) as handle:
pickle.dump(
best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL
......
......@@ -107,7 +107,8 @@ def get_callbacks(
def tune_search(
data: List[np.array],
bayopt_trials: int,
hypertun_trials: int,
hpt_type: str,
hypermodel: str,
k: int,
kl_warmup_epochs: int,
......@@ -126,7 +127,8 @@ def tune_search(
Parameters:
- train (np.array): dataset to train the model on
- test (np.array): dataset to validate the model on
- bayopt_trials (int): number of Bayesian optimization iterations to run
- hypertun_trials (int): number of hyperparameter tuning iterations to run
- hpt_type (str): specify one of Bayesian Optimization (bayopt) or Hyperband (hyperband)
- hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder)
or S2SGMVAE (Gaussian Mixture Variational autoencoder).
- k (int) number of components of the Gaussian Mixture
......@@ -151,6 +153,10 @@ def tune_search(
X_train, y_train, X_val, y_val = data
assert hpt_type in ["bayopt", "hyperband"], (
"Invalid hyperparameter tuning framework. " "Select one of bayopt and hyperband"
)
if hypermodel == "S2SAE": # pragma: no cover
assert (
predictor == 0.0 and pheno_class == 0.0
......@@ -172,17 +178,28 @@ def tune_search(
else:
return False
tuner = Hyperband(
hypermodel,
directory="HyperBandx_{}_{}".format(loss, str(date.today())),
executions_per_trial=n_replicas,
logger=TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"),
max_epochs=bayopt_trials,
objective="val_mae",
project_name=project_name,
seed=42,
tune_new_entries=True,
)
hpt_params = {
"hypermodel": hypermodel,
"executions_per_trial": n_replicas,
"logger": TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"),
"objective": "val_mae",
"project_name": project_name,
"seed": 42,
"tune_new_entries": True,
}
if hpt_type == "hyperband":
tuner = Hyperband(
directory="HyperBandx_{}_{}".format(loss, str(date.today())),
max_epochs=hypertun_trials,
**hpt_params
)
else:
tuner = BayesianOptimization(
directory="BayOpt_{}_{}".format(loss, str(date.today())),
max_trials=hypertun_trials,
**hpt_params
)
print(tuner.search_space_summary())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment