From 509f51875b760fb1fa9499d74c5faf7a5bad239c Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 4 Dec 2020 20:58:00 +0100
Subject: [PATCH] Adds output path option to train_model.py

---
 deepof/train_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index f8d1deee..645bc889 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -72,6 +72,7 @@ def get_callbacks(
     predictor: float,
     loss: str,
     logparam: dict = None,
+    outpath: str = ".",
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;
@@ -92,7 +93,7 @@ def get_callbacks(
         datetime.now().strftime("%Y%m%d-%H%M%S"),
     )
 
-    log_dir = os.path.abspath("logs/fit/{}".format(run_ID))
+    log_dir = os.path.abspath(os.path.join(outpath, "fit", run_ID))
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
         log_dir=log_dir,
         histogram_freq=1,
@@ -108,7 +109,7 @@ def get_callbacks(
 
     if cp:
         cp_callback = tf.keras.callbacks.ModelCheckpoint(
-            "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt",
+            os.path.join(outpath,"checkpoints", run_ID + "/cp-{epoch:04d}.ckpt"),
             verbose=1,
             save_best_only=False,
             save_weights_only=True,
-- 
GitLab