diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index b852bf1b05d657208216f7cfe1b7491f45d56576..d20b91848c4e8b3d245c2f9b66fc8945d17cfc28 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -71,15 +71,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
 
     def __init__(
         self,
-        input_shape,
-        entropy_reg_weight=0.0,
-        huber_delta=1.0,
-        learn_rate=1e-3,
-        loss="ELBO+MMD",
-        number_of_components=10,
-        overlap_loss=False,
-        predictor=0.0,
-        prior="standard_normal",
+        input_shape: tuple,
+        entropy_reg_weight: float = 0.0,
+        huber_delta: float = 1.0,
+        learn_rate: float = 1e-3,
+        loss: str = "ELBO+MMD",
+        number_of_components: int = 10,
+        overlap_loss: bool = False,
+        phenotype_predictor: float = 0.0,
+        predictor: float = 0.0,
+        prior: str = "standard_normal",
     ):
         super().__init__()
         self.input_shape = input_shape
@@ -89,6 +90,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         self.loss = loss
         self.number_of_components = number_of_components
         self.overlap_loss = overlap_loss
+        self.pheno_class = phenotype_predictor
         self.predictor = predictor
         self.prior = prior
 
@@ -161,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             loss=self.loss,
             number_of_components=k,
             overlap_loss=self.overlap_loss,
+            phenotype_prediction=self.pheno_class,
             predictor=self.predictor,
         ).build(self.input_shape)[3:]
 
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 5eb99b6eb942a621a150cbb7241e286c6df73e7c..f4f4f15119efe2a44e18673a9e1af16d40787838 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -152,6 +152,7 @@ def tune_search(
             loss=loss,
             number_of_components=k,
             overlap_loss=overlap_loss,
+            phenotype_predictor=pheno_class,
             predictor=predictor,
         )
 
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index 74587dbfd0f8aec86572241c1b84d6040a1ed35f..0e192a9575053dea2ed293a484189d87c2f08af9 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -65,7 +65,7 @@ def test_get_callbacks(
     assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
 
 
-@settings(max_examples=1, deadline=None)
+@settings(max_examples=5, deadline=None)
 @given(
     X_train=arrays(
         dtype=float,
@@ -93,7 +93,7 @@ def test_tune_search(
         )
     )[1:]
 
-    y_train = tf.random.uniform(shape=(X_train.shape[0],), maxval=1.0)
+    y_train = tf.random.uniform(shape=(X_train.shape[1],), maxval=1.0)
 
     deepof.train_utils.tune_search(
         data=[X_train, y_train, X_train, y_train],