diff --git a/deepof/train_model.py b/deepof/train_model.py index 7702c1d5efafe9767dddd5037847483266432ff9..1221027217923479daa627ce932c6cb729a229b6 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -411,7 +411,7 @@ else: start_epoch=max(kl_wu, mmd_wu), ), ], - n_replicas=3, + n_replicas=1, n_epochs=30, outpath=output_path, ) diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 0aab06b70e1e959deac42df39ba53fe08b139cc5..081911b3a1d8350e42f8f15014bfd53e4b42db2e 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -545,7 +545,9 @@ def tune_search( if hpt_type == "hyperband": tuner = Hyperband( - directory=os.path.join(outpath, "HyperBandx_{}_{}".format(loss, str(date.today()))), + directory=os.path.join( + outpath, "HyperBandx_{}_{}".format(loss, str(date.today())) + ), max_epochs=35, hyperband_iterations=hypertun_trials, factor=2, @@ -553,7 +555,9 @@ def tune_search( ) else: tuner = BayesianOptimization( - directory=os.path.join(outpath, "BayOpt_{}_{}".format(loss, str(date.today()))), + directory=os.path.join( + outpath, "BayOpt_{}_{}".format(loss, str(date.today())) + ), max_trials=hypertun_trials, **hpt_params ) @@ -577,7 +581,7 @@ def tune_search( epochs=n_epochs, validation_data=(Xvals, yvals), verbose=1, - batch_size=32, + batch_size=64, callbacks=callbacks, )