diff --git a/deepof/train_model.py b/deepof/train_model.py
index f2c69eac751d68260af9d5787c32cf3c9e3d7728..c604c5f943a7d36a6993ffbd8ab31afe9b3d6fa8 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -13,6 +13,7 @@ from deepof.data import *
 from deepof.models import *
 from deepof.utils import *
 from deepof.train_utils import *
+from tensorboard.plugins.hparams import api as hp
 
 parser = argparse.ArgumentParser(
     description="Autoencoder training for DeepOF animal pose recognition"
@@ -106,6 +107,14 @@ parser.add_argument(
     default=10,
     type=int,
 )
+parser.add_argument(
+    "--logparam",
+    "-lp",
+    help="Logs the specified parameter to tensorboard. None by default. "
+    "Supported values are: encoding",
+    default=None,
+    type=str,
+)
 parser.add_argument(
     "--loss",
     "-l",
@@ -202,6 +211,7 @@ hparams = args.hyperparameters
 input_type = args.input_type
 k = args.components
 kl_wu = args.kl_warmup
+logparam = args.logparam
 loss = args.loss
 mmd_wu = args.mmd_warmup
 overlap_loss = args.overlap_loss
@@ -237,6 +247,10 @@ assert input_type in [
 hparams = load_hparams(hparams)
 treatment_dict = load_treatments(train_path)
 
+# Sets up hyperparameter logging for tensorboard if specified via the --logparam CLI argument
+if logparam == "encoding":
+    logparam = {"encoding": encoding_size}
+
 # noinspection PyTypeChecker
 project_coords = project(
     animal_ids=tuple([animal_id]),
@@ -317,14 +331,19 @@ if not tune:
         tf.keras.backend.clear_session()
 
         run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
-            X_train,
-            batch_size,
-            True,
-            variational,
-            predictor,
-            loss,
+            X_train=X_train,
+            batch_size=batch_size,
+            cp=True,
+            variational=variational,
+            phenotype_class=pheno_class,
+            predictor=predictor,
+            loss=loss,
+            logparam=logparam,
         )
 
+        if logparam is not None and "encoding" in logparam:
+            logparam["encoding"] = hp.HParam("encoding_dim", hp.Discrete([encoding_size]))
+
         if not variational:
             encoder, decoder, ae = SEQ_2_SEQ_AE(hparams).build(X_train.shape)
             print(ae.summary())
@@ -414,7 +433,9 @@ if not tune:
             )
 
             gmvaep.save_weights(
-                "GMVAE_loss={}_encoding={}_final_weights.h5".format(loss, encoding_size)
+                "GMVAE_loss={}_encoding={}_run_{}_final_weights.h5".format(
+                    loss, encoding_size, run
+                )
             )
 
         # To avoid stability issues
@@ -426,7 +447,14 @@ else:
     hyp = "S2SGMVAE" if variational else "S2SAE"
 
     run_ID, tensorboard_callback, onecycle = get_callbacks(
-        X_train, batch_size, False, variational, predictor, loss
+        X_train=X_train,
+        batch_size=batch_size,
+        cp=False,
+        variational=variational,
+        phenotype_class=pheno_class,
+        predictor=predictor,
+        loss=loss,
+        logparam=None,
     )
 
     best_hyperparameters, best_model = tune_search(
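
As context for the new hparams import and the logparam plumbing above, a minimal, illustrative sketch of how the tensorboard hparams API is typically wired up; the log directory, metric name, and encoding sizes below are assumptions for illustration, not values from the deepof codebase:

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

# Declare the hyperparameter and its admissible values (illustrative)
ENCODING = hp.HParam("encoding_dim", hp.Discrete([4, 8, 16]))

# Inside a training run: record the value used by this run, plus a metric
# that tensorboard's HPARAMS tab can compare across runs
with tf.summary.create_file_writer("logs/fit/example_run").as_default():
    hp.hparams({ENCODING: 16})
    tf.summary.scalar("val_mae", 0.42, step=1)
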
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 2b10736dec2f574650bae673864ce241c632dfb6..f8d1deee4aa223c682ed4b58602c459e71ac813b 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -68,8 +68,10 @@ def get_callbacks(
     batch_size: int,
     cp: bool,
     variational: bool,
+    phenotype_class: float,
     predictor: float,
     loss: str,
+    logparam: dict = None,
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;
@@ -77,10 +79,16 @@ def get_callbacks(
     - cp_callback: for checkpoint saving,
     - onecycle: for learning rate scheduling"""
 
-    run_ID = "{}{}{}_{}".format(
+    run_ID = "{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
         ("P" if predictor > 0 and variational else ""),
+        ("_Pheno" if phenotype_class > 0 else ""),
         ("_loss={}".format(loss) if variational else ""),
+        (
+            "_{}={}".format(list(logparam.keys())[0], list(logparam.values())[0])
+            if logparam is not None
+            else ""
+        ),
         datetime.now().strftime("%Y%m%d-%H%M%S"),
     )
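
For reference, a quick sketch of what the extended run_ID composition yields; the loss, encoding size, and branch flags used here are hypothetical, not taken from an actual run:

from datetime import datetime

# Hypothetical configuration: variational model with predictor and phenotype
# branches, ELBO+MMD loss, and logparam={"encoding": 16}
run_ID = "{}{}{}{}{}_{}".format(
    "GMVAE",
    "P",
    "_Pheno",
    "_loss=ELBO+MMD",
    "_encoding=16",
    datetime.now().strftime("%Y%m%d-%H%M%S"),
)
# e.g. "GMVAEP_Pheno_loss=ELBO+MMD_encoding=16_20210101-120000"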