Commit 509f5187 authored by lucas_miranda's avatar lucas_miranda
Browse files

Adds output path option to train_model.py

parent 8e00f0ca
Pipeline #88835 failed with stage
in 18 minutes and 38 seconds
......@@ -72,6 +72,7 @@ def get_callbacks(
predictor: float,
loss: str,
logparam: dict = None,
outpath: str = ".",
) -> List[Union[Any]]:
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
......@@ -92,7 +93,7 @@ def get_callbacks(
datetime.now().strftime("%Y%m%d-%H%M%S"),
)
log_dir = os.path.abspath("logs/fit/{}".format(run_ID))
log_dir = os.path.abspath(os.path.join(outpath, "fit", run_ID))
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1,
......@@ -108,7 +109,7 @@ def get_callbacks(
if cp:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
os.path.join(outpath,"checkpoints", run_ID + "/cp-{epoch:04d}.ckpt"),
verbose=1,
save_best_only=False,
save_weights_only=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment