diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index 0c2d063be3bb0c76c348bcbd6c872e428feec9c0..9c9e983d7f2869061808883a2249866e55e8343d 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -102,8 +102,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         mmd_warmup_epochs: int = 0,
         number_of_components: int = 10,
         overlap_loss: float = False,
-        phenotype_predictor: float = 0.0,
-        predictor: float = 0.0,
+        next_sequence_prediction: float = 0.0,
+        phenotype_prediction: float = 0.0,
+        rule_based_prediction: float = 0.0,
         prior: str = "standard_normal",
     ):
         super().__init__()
@@ -115,8 +116,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         self.mmd_warmup_epochs = mmd_warmup_epochs
         self.number_of_components = number_of_components
         self.overlap_loss = overlap_loss
-        self.pheno_class = phenotype_predictor
-        self.predictor = predictor
+        self.next_sequence_prediction = next_sequence_prediction
+        self.phenotype_prediction = phenotype_prediction
+        self.rule_based_prediction = rule_based_prediction
         self.prior = prior
 
         assert (
@@ -186,8 +188,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             mmd_warmup_epochs=self.mmd_warmup_epochs,
             number_of_components=k,
             overlap_loss=self.overlap_loss,
-            phenotype_prediction=self.pheno_class,
-            next_sequence_prediction=self.predictor,
+            next_sequence_prediction=self.next_sequence_prediction,
+            phenotype_prediction=self.phenotype_prediction,
+            rule_based_prediction=self.rule_based_prediction,
         ).build(self.input_shape)[-3]
 
         return gmvaep
diff --git a/deepof/train_model.py b/deepof/train_model.py
index 5b4436cfbc393ed8070a4049bb8978fbdd2addb3..b32449b464450f5848579ef944c77c38344a5c4f 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -415,8 +415,8 @@ else:
         loss=loss,
         mmd_warmup_epochs=mmd_wu,
         overlap_loss=overlap_loss,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        phenotype_prediction=pheno_class,
+        next_sequence_prediction=predictor,
         project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
         callbacks=[
             tensorboard_callback,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index a8cbc1fc3979fef177e989047f562c8f399b21dc..a3c3491001fbc525beaac621fc43fe4e38df19d6 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -472,8 +472,9 @@ def tune_search(
     loss: str,
     mmd_warmup_epochs: int,
     overlap_loss: float,
-    phenotype_class: float,
-    predictor: float,
+    next_sequence_prediction: float,
+    phenotype_prediction: float,
+    rule_based_prediction: float,
     project_name: str,
     callbacks: List,
     n_epochs: int = 30,
@@ -517,7 +518,7 @@ def tune_search(
 
     if hypermodel == "S2SAE":  # pragma: no cover
         assert (
-            predictor == 0.0 and phenotype_class == 0.0
+                next_sequence_prediction == 0.0 and phenotype_prediction == 0.0
         ), "Prediction branches are only available for variational models. See documentation for more details"
         batch_size = 1
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
@@ -532,8 +533,9 @@ def tune_search(
             mmd_warmup_epochs=mmd_warmup_epochs,
             number_of_components=k,
             overlap_loss=overlap_loss,
-            phenotype_predictor=phenotype_class,
-            predictor=predictor,
+            next_sequence_prediction=next_sequence_prediction,
+            phenotype_prediction=phenotype_prediction,
+            rule_based_prediction=rule_based_prediction,
         )
 
     else:
@@ -574,11 +576,19 @@ def tune_search(
     Xs, ys = [X_train], [X_train]
     Xvals, yvals = [X_val], [X_val]
 
-    if predictor > 0.0:
+    if next_sequence_prediction > 0.0:
         Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
         Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
 
-    if phenotype_class > 0.0:
+    if phenotype_prediction > 0.0:
+        ys += [y_train[:, 0]]
+        yvals += [y_val[:, 0]]
+
+        # Remove the used column (phenotype) from both y arrays
+        y_train = y_train[:, 1:]
+        y_val = y_val[:, 1:]
+
+    if rule_based_prediction > 0.0:
         ys += [y_train]
         yvals += [y_val]
 
diff --git a/tests/test_build_hypermodels.py b/tests/test_build_hypermodels.py
index 4753558024815cb237ac26815511eabab132866e..d803789e8cf9d0bbaaae586b386f7b9cb7be361c 100644
--- a/tests/test_build_hypermodels.py
+++ b/tests/test_build_hypermodels.py
@@ -51,5 +51,5 @@ def test_SEQ_2_SEQ_GMVAE_hypermodel_build(
         ),
         loss=loss,
         number_of_components=number_of_components,
-        predictor=True,
+        next_sequence_prediction=True,
     ).build(hp=HyperParameters())
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index 53b2d921a1a04642f57db9b24cc125f3fe69d5c8..e8aa23cfc8cf2eea52757c4638f89405d8a33ddd 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -193,8 +193,8 @@ def test_tune_search(
         loss=loss,
         mmd_warmup_epochs=0,
         overlap_loss=overlap_loss,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        phenotype_prediction=pheno_class,
+        next_sequence_prediction=predictor,
         project_name="test_run",
         callbacks=callbacks,
         n_epochs=1,