Commit 3b845054 authored by lucas_miranda's avatar lucas_miranda
Browse files

Changed default hyperparameter values

parent e240953e
Pipeline #87073 failed with stage
in 13 minutes and 55 seconds
......@@ -308,8 +308,8 @@ if not tune:
# To avoid stability issues
tf.keras.backend.clear_session()
run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
X_train, batch_size, variational, predictor, loss,
run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
X_train, batch_size, True, variational, predictor, loss,
)
if not variational:
......@@ -406,8 +406,8 @@ else:
hyp = "S2SGMVAE" if variational else "S2SAE"
run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
X_train, batch_size, variational, predictor, loss
run_ID, tensorboard_callback, onecycle = get_callbacks(
X_train, batch_size, False, variational, predictor, loss
)
best_hyperparameters, best_model = tune_search(
......
......@@ -61,7 +61,7 @@ def load_treatments(train_path):
def get_callbacks(
X_train: np.array, batch_size: int, variational: bool, predictor: float, loss: str,
X_train: np.array, batch_size: int, cp: bool, variational: bool, predictor: float, loss: str,
) -> Tuple:
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
......@@ -81,19 +81,23 @@ def get_callbacks(
log_dir=log_dir, histogram_freq=1, profile_batch=2,
)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
verbose=1,
save_best_only=False,
save_weights_only=True,
save_freq="epoch",
)
onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250, max_rate=0.005,
)
return run_ID, tensorboard_callback, cp_callback, onecycle
callbacks = [run_ID, tensorboard_callback, onecycle]
if cp:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
verbose=1,
save_best_only=False,
save_weights_only=True,
save_freq="epoch",
)
callbacks.append(cp_callback)
return callbacks
def tune_search(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment