diff --git a/deepof/train_utils.py b/deepof/train_utils.py index f8d1deee4aa223c682ed4b58602c459e71ac813b..645bc889e2ef491fc8ac1e8db2942c278d5d771d 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -72,6 +72,7 @@ def get_callbacks( predictor: float, loss: str, logparam: dict = None, + outpath: str = ".", ) -> List[Union[Any]]: """Generates callbacks for model training, including: - run_ID: run name, with coarse parameter details; @@ -92,7 +93,7 @@ def get_callbacks( datetime.now().strftime("%Y%m%d-%H%M%S"), ) - log_dir = os.path.abspath("logs/fit/{}".format(run_ID)) + log_dir = os.path.abspath(os.path.join(outpath, "fit", run_ID)) tensorboard_callback = tf.keras.callbacks.TensorBoard( log_dir=log_dir, histogram_freq=1, @@ -108,7 +109,7 @@ def get_callbacks( if cp: cp_callback = tf.keras.callbacks.ModelCheckpoint( - "./logs/checkpoints/" + run_ID + "/cp-{epoch:04d}.ckpt", + os.path.join(outpath,"checkpoints", run_ID + "/cp-{epoch:04d}.ckpt"), verbose=1, save_best_only=False, save_weights_only=True,