diff --git a/deepof/train_model.py b/deepof/train_model.py
index 0b233b97b57eea2a1c439cd5112e8a286de0cfbd..10b2a3825a9655706541f823dbb12050a8da3711 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -268,6 +268,8 @@ logparam = {
     "k": k,
     "loss": loss,
 }
+if pheno_class:
+    logparam["pheno_weight"] = pheno_class
 
 # noinspection PyTypeChecker
 project_coords = project(
@@ -293,7 +295,7 @@ coords = project_coords.get_coords(
     center=animal_id + undercond + "Center",
     align=animal_id + undercond + "Spine_1",
     align_inplace=True,
-    propagate_labels=(pheno_class > 0)
+    propagate_labels=(pheno_class > 0),
 )
 distances = project_coords.get_distances()
 angles = project_coords.get_angles()
@@ -388,17 +390,40 @@ if not tune:
             ),
         ]
 
+        rec = "reconstruction_" if (pheno_class) else ""
+        metrics = [
+            hp.Metric("val_{}mae".format(rec), display_name="val_{}mae".format(rec)),
+            hp.Metric("val_{}mse".format(rec), display_name="val_{}mse".format(rec)),
+        ]
         logparam["run"] = run
+        if pheno_class:
+            logparams.append(
+                hp.HParam(
+                    "pheno_weight",
+                    hp.RealInterval(min_value=0, max_value=1000),
+                    display_name="pheno weight",
+                    description="weight applied to phenotypic classifier from the latent space",
+                )
+            )
+            metrics.append(
+                [
+                    hp.Metric(
+                        "phenotype_prediction_accuracy",
+                        display_name="phenotype_prediction_accuracy",
+                    ),
+                    hp.Metric(
+                        "phenotype_prediction_auc",
+                        display_name="phenotype_prediction_auc",
+                    ),
+                ]
+            )
 
         with tf.summary.create_file_writer(
             os.path.join(output_path, "hparams", run_ID)
         ).as_default():
             hp.hparams_config(
                 hparams=logparams,
-                metrics=[
-                    hp.Metric("val_mae", display_name="val_mae"),
-                    hp.Metric("val_mse", display_name="val_mse"),
-                ],
+                metrics=metrics,
             )
 
         if not variational:
@@ -512,23 +537,16 @@ if not tune:
             )
 
             # Logparams to tensorboard
-            def run(run_dir, hpms):
-                with tf.summary.create_file_writer(run_dir).as_default():
-                    hp.hparams(hpms)  # record the values used in this trial
-                    val_mae = tf.reduce_mean(
-                        tf.keras.metrics.mean_absolute_error(
-                            X_val, gmvaep.predict(X_val)
-                        )
-                    )
-                    val_mse = tf.reduce_mean(
-                        tf.keras.metrics.mean_squared_error(
-                            X_val, gmvaep.predict(X_val)
-                        )
-                    )
-                    tf.summary.scalar("val_mae", val_mae, step=1)
-                    tf.summary.scalar("val_mse", val_mse, step=1)
-
-            run(os.path.join(output_path, "hparams", run_ID), logparam)
+            tensorboard_metric_logging(
+                os.path.join(output_path, "hparams", run_ID),
+                logparam,
+                X_val,
+                y_val,
+                gmvaep,
+                predictor,
+                pheno_class,
+                rec,
+            )
 
         # To avoid stability issues
         tf.keras.backend.clear_session()
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index a8ac8d9dbdb046f957f373cf54095676681ac151..6f227b1d14b2a645c20d18183527cf8c5252f5f0 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -12,6 +12,7 @@ from datetime import date, datetime
 from kerastuner import BayesianOptimization, Hyperband
 from kerastuner import HyperParameters
 from kerastuner_tensorboard_logger import TensorBoardLogger
+from tensorboard.plugins.hparams import api as hp
 from typing import Tuple, Union, Any, List
 import deepof.hypermodels
 import deepof.model_utils
@@ -136,6 +137,52 @@ def get_callbacks(
     return callbacks
 
 
+# noinspection PyUnboundLocalVariable
+def tensorboard_metric_logging(
+    run_dir: str,
+    hpms: Any,
+    X: np.ndarray,
+    y: np.ndarray,
+    model: tensorflow.python.keras.engine.functional.Functional,
+    predictor: float,
+    pheno_class: float,
+    rec: str,
+):
+    output = model.predict(X)
+    if pheno_class or predictor:
+        reconstruction = output[0]
+        prediction = output[1]
+        pheno = output[-1]
+    else:
+        reconstruction = output
+
+    with tf.summary.create_file_writer(run_dir).as_default():
+        hp.hparams(hpms)  # record the values used in this trial
+        val_mae = tf.reduce_mean(
+            tf.keras.metrics.mean_absolute_error(X, reconstruction)
+        )
+        val_mse = tf.reduce_mean(tf.keras.metrics.mean_squared_error(X, reconstruction))
+        tf.summary.scalar("val_{}mae".format(rec), val_mae, step=1)
+        tf.summary.scalar("val_{}mse".format(rec), val_mse, step=1)
+
+        if predictor:
+            pred_mae = tf.reduce_mean(
+                tf.keras.metrics.mean_absolute_error(X, prediction)
+            )
+            pred_mse = tf.reduce_mean(
+                tf.keras.metrics.mean_squared_error(X, prediction)
+            )
+            tf.summary.scalar("val_prediction_mae".format(rec), pred_mae, step=1)
+            tf.summary.scalar("val_prediction_mse".format(rec), pred_mse, step=1)
+
+        if pheno_class:
+            pheno_acc = tf.keras.metrics.Accuracy(y, pheno)
+            pheno_auc = tf.keras.metrics.AUC(y, pheno)
+
+            tf.summary.scalar("phenotype_prediction_accuracy", pheno_acc, step=1)
+            tf.summary.scalar("phenotype_prediction_auc", pheno_auc, step=1)
+
+
 def tune_search(
     data: List[np.array],
     encoding_size: int,