diff --git a/deepof/train_model.py b/deepof/train_model.py
index 2448108afdfbd6061daf8385bc6605a02d0a8a4e..122555f0f70ab9df004c5c6ee95735247a8f802e 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -411,8 +411,7 @@ else:
     )
 
     best_hyperparameters, best_model = tune_search(
-        X_train,
-        X_val,
+        data=[X_train, y_train, X_val, y_val],
         bayopt_trials=bayopt_trials,
         hypermodel=hyp,
         k=k,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 000f52addb945b64b27d364327653320ac5a2b96..b545268f3061c32f12adbcdb0241920baa517c56 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -97,8 +97,7 @@ def get_callbacks(
 
 
 def tune_search(
-    train: np.array,
-    test: np.array,
+    data: List[np.array],
     bayopt_trials: int,
     hypermodel: str,
     k: int,
@@ -139,12 +138,14 @@ def tune_search(
 
     """
 
+    X_train, y_train, X_val, y_val = data
+
     if hypermodel == "S2SAE":  # pragma: no cover
-        hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
+        hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
 
     elif hypermodel == "S2SGMVAE":
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
-            input_shape=train.shape,
+            input_shape=X_train.shape,
             loss=loss,
             number_of_components=k,
             overlap_loss=overlap_loss,
@@ -168,13 +169,22 @@ def tune_search(
 
     print(tuner.search_space_summary())
 
+    Xs, ys = [X_train], [X_train]
+    Xvals, yvals = [X_val], [X_val]
+
+    if predictor > 0.0:
+        Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
+        Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
+
+    if pheno_class > 0.0:
+        ys += [y_train]
+        yvals += [y_val]
+
     tuner.search(
-        train if predictor == 0 else [train[:-1]],
-        train if predictor == 0 else [train[:-1], train[1:]],
+        Xs,
+        ys,
         epochs=n_epochs,
-        validation_data=(
-            (test, test) if predictor == 0 else (test[:-1], [test[:-1], test[1:]])
-        ),
+        validation_data=(Xvals, yvals),
         verbose=1,
         batch_size=256,
         callbacks=callbacks,