diff --git a/deepof/train_model.py b/deepof/train_model.py
index 08ebc93fbdcce85bc239275e10b3e91dd7e99096..e75d73795455624267ecf6e19c2d4ff535911d0e 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -40,13 +40,6 @@ parser.add_argument(
     type=int,
     default=512,
 )
-parser.add_argument(
-    "--bayopt",
-    "-n",
-    help="sets the number of Bayesian optimization iterations to run. Default is 25",
-    type=int,
-    default=25,
-)
 parser.add_argument(
     "--components",
     "-k",
@@ -68,10 +61,18 @@ parser.add_argument(
     type=str2bool,
     default=False,
 )
+parser.add_argument(
+    "--hpt_trials",
+    "-n",
+    help="sets the number of hyperparameter tuning iterations to run. Default is 25",
+    type=int,
+    default=25,
+)
 parser.add_argument(
     "--hyperparameter-tuning",
     "-tune",
-    help="If True, hyperparameter tuning is performed. See documentation for details",
+    help="Indicates whether hyperparameters should be tuned using either 'bayopt' or 'hyperband'. "
+    "See documentation for details",
     type=str2bool,
     default=False,
 )
@@ -187,7 +188,7 @@ args = parser.parse_args()
 animal_id = args.animal_id
 arena_dims = args.arena_dims
 batch_size = args.batch_size
-bayopt_trials = args.bayopt
+hypertun_trials = args.hpt_trials
 exclude_bodyparts = list(args.exclude_bodyparts.split(","))
 gaussian_filter = args.gaussian_filter
 hparams = args.hyperparameters
@@ -412,7 +413,8 @@ else:
 
     best_hyperparameters, best_model = tune_search(
         data=[X_train, y_train, X_val, y_val],
-        bayopt_trials=bayopt_trials,
+        hypertun_trials=hypertun_trials,
+        hpt_type=tune,
         hypermodel=hyp,
         k=k,
         kl_warmup_epochs=kl_wu,
@@ -421,7 +423,7 @@ else:
         overlap_loss=overlap_loss,
         pheno_class=pheno_class,
         predictor=predictor,
-        project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
+        project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
         callbacks=[
             tensorboard_callback,
             onecycle,
@@ -436,12 +438,12 @@ else:
     # Saves a compiled, untrained version of the best model
     best_model.build(X_train.shape)
     best_model.save(
-        "{}-based_{}_BAYESIAN_OPT.h5".format(input_type, hyp), save_format="tf"
+        "{}-based_{}_{}.h5".format(input_type, hyp, tune.capitalize()), save_format="tf"
     )
 
     # Saves the best hyperparameters
     with open(
-        "{}-based_{}_BAYESIAN_OPT_params.pickle".format(input_type, hyp), "wb"
+        "{}-based_{}_{}_params.pickle".format(input_type, hyp, tune.capitalize()), "wb"
     ) as handle:
         pickle.dump(
             best_hyperparameters.values, handle, protocol=pickle.HIGHEST_PROTOCOL
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index e06bd0ad2674b5b1bcaf58fca64221e5397be0af..e4c6a33fc72ce7b74be55f2dbd60446ec9419788 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -107,7 +107,8 @@ def get_callbacks(
 
 def tune_search(
     data: List[np.array],
-    bayopt_trials: int,
+    hypertun_trials: int,
+    hpt_type: str,
     hypermodel: str,
     k: int,
     kl_warmup_epochs: int,
@@ -126,7 +127,8 @@ def tune_search(
     Parameters:
         - train (np.array): dataset to train the model on
         - test (np.array): dataset to validate the model on
-        - bayopt_trials (int): number of Bayesian optimization iterations to run
+        - hypertun_trials (int): number of hyperparameter tuning iterations to run
+        - hpt_type (str): hyperparameter tuning framework to use; one of 'bayopt'
+        (Bayesian optimization) or 'hyperband' (Hyperband)
         - hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder)
         or S2SGMVAE (Gaussian Mixture Variational autoencoder).
         - k (int): number of components of the Gaussian Mixture
@@ -151,6 +153,10 @@
 
     X_train, y_train, X_val, y_val = data
 
+    assert hpt_type in ["bayopt", "hyperband"], (
+        "Invalid hyperparameter tuning framework. Select one of 'bayopt' or 'hyperband'"
+    )
+
     if hypermodel == "S2SAE":  # pragma: no cover
         assert (
             predictor == 0.0 and pheno_class == 0.0
@@ -172,17 +178,28 @@
     else:
         return False
 
-    tuner = Hyperband(
-        hypermodel,
-        directory="HyperBandx_{}_{}".format(loss, str(date.today())),
-        executions_per_trial=n_replicas,
-        logger=TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"),
-        max_epochs=bayopt_trials,
-        objective="val_mae",
-        project_name=project_name,
-        seed=42,
-        tune_new_entries=True,
-    )
+    hpt_params = {
+        "hypermodel": hypermodel,
+        "executions_per_trial": n_replicas,
+        "logger": TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"),
+        "objective": "val_mae",
+        "project_name": project_name,
+        "seed": 42,
+        "tune_new_entries": True,
+    }
+
+    if hpt_type == "hyperband":
+        tuner = Hyperband(
+            directory="HyperBandx_{}_{}".format(loss, str(date.today())),
+            max_epochs=hypertun_trials,
+            **hpt_params
+        )
+    else:
+        tuner = BayesianOptimization(
+            directory="BayOpt_{}_{}".format(loss, str(date.today())),
+            max_trials=hypertun_trials,
+            **hpt_params
+        )
 
     print(tuner.search_space_summary())
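The excerpt does not show an import hunk for `BayesianOptimization`, which the new `bayopt` branch needs alongside `Hyperband`. Below is a minimal, self-contained sketch of the shared-kwargs tuner dispatch the patch introduces, assuming the keras-tuner 1.x package (`kerastuner`); `build_model` and `make_tuner` are illustrative placeholders, not names from the patch, and deepof's SEQ_2_SEQ hypermodels and TensorBoardLogger are omitted to keep the sketch runnable:

```python
# Minimal sketch of the tuner dispatch introduced in train_utils.py.
# Assumptions (not from the patch): keras-tuner 1.x installed as `kerastuner`,
# and `build_model` as a stand-in hypermodel.
from datetime import date

from kerastuner.tuners import BayesianOptimization, Hyperband
from tensorflow import keras


def build_model(hp):
    """Placeholder hypermodel with a single tunable layer width."""
    model = keras.Sequential(
        [
            keras.layers.Dense(hp.Int("units", 32, 256, step=32), activation="relu"),
            keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mae", metrics=["mae"])
    return model


def make_tuner(hpt_type: str, hypertun_trials: int, loss: str, project_name: str):
    """Return a Hyperband or BayesianOptimization tuner sharing common settings."""
    assert hpt_type in ["bayopt", "hyperband"], (
        "Invalid hyperparameter tuning framework. Select one of 'bayopt' or 'hyperband'"
    )

    # Settings common to both tuners, mirroring `hpt_params` in the patch
    hpt_params = {
        "hypermodel": build_model,
        "executions_per_trial": 1,
        "objective": "val_mae",
        "project_name": project_name,
        "seed": 42,
        "tune_new_entries": True,
    }

    if hpt_type == "hyperband":
        # Hyperband's budget is a maximum number of training epochs per model
        return Hyperband(
            directory="HyperBandx_{}_{}".format(loss, str(date.today())),
            max_epochs=hypertun_trials,
            **hpt_params
        )
    # Bayesian optimization's budget is a maximum number of trials
    return BayesianOptimization(
        directory="BayOpt_{}_{}".format(loss, str(date.today())),
        max_trials=hypertun_trials,
        **hpt_params
    )


# Example usage:
# tuner = make_tuner("bayopt", 25, "ELBO", "tuner_sketch")
# tuner.search(X_train, y_train, validation_data=(X_val, y_val), epochs=10)
```

One consequence of reusing `hypertun_trials` for both branches is that the budget means different things to the two tuners: `max_epochs` caps how long Hyperband trains any single candidate, while `max_trials` caps how many hyperparameter configurations Bayesian optimization evaluates.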