Skip to content
Snippets Groups Projects
Commit 509f5187 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Adds output path option to train_model.py

parent 8e00f0ca
No related branches found
No related tags found
No related merge requests found
Pipeline #88835 failed
...@@ -72,6 +72,7 @@ def get_callbacks( ...@@ -72,6 +72,7 @@ def get_callbacks(
predictor: float, predictor: float,
loss: str, loss: str,
logparam: dict = None, logparam: dict = None,
outpath: str = ".",
) -> List[Union[Any]]: ) -> List[Union[Any]]:
"""Generates callbacks for model training, including: """Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details; - run_ID: run name, with coarse parameter details;
...@@ -92,7 +93,7 @@ def get_callbacks( ...@@ -92,7 +93,7 @@ def get_callbacks(
datetime.now().strftime("%Y%m%d-%H%M%S"), datetime.now().strftime("%Y%m%d-%H%M%S"),
) )
log_dir = os.path.abspath("logs/fit/{}".format(run_ID)) log_dir = os.path.abspath(os.path.join(outpath, "fit", run_ID))
tensorboard_callback = tf.keras.callbacks.TensorBoard( tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=log_dir, log_dir=log_dir,
histogram_freq=1, histogram_freq=1,
...@@ -108,7 +109,7 @@ def get_callbacks( ...@@ -108,7 +109,7 @@ def get_callbacks(
if cp: if cp:
cp_callback = tf.keras.callbacks.ModelCheckpoint( cp_callback = tf.keras.callbacks.ModelCheckpoint(
"./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt", os.path.join(outpath,"checkpoints", run_ID + "/cp-{epoch:04d}.ckpt"),
verbose=1, verbose=1,
save_best_only=False, save_best_only=False,
save_weights_only=True, save_weights_only=True,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment