From db90399afdab351d7a54aafbba74bcb3ee153e81 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 17 Nov 2020 11:27:29 +0100
Subject: [PATCH] Updated train_model.py to be compatible with phenotype
 classification

---
 deepof/hypermodels.py     | 21 ++++++++++++---------
 deepof/train_utils.py     |  1 +
 tests/test_train_utils.py |  4 ++--
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index b852bf1b..d20b9184 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -71,15 +71,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
 
     def __init__(
         self,
-        input_shape,
-        entropy_reg_weight=0.0,
-        huber_delta=1.0,
-        learn_rate=1e-3,
-        loss="ELBO+MMD",
-        number_of_components=10,
-        overlap_loss=False,
-        predictor=0.0,
-        prior="standard_normal",
+        input_shape: tuple,
+        entropy_reg_weight: float = 0.0,
+        huber_delta: float = 1.0,
+        learn_rate: float = 1e-3,
+        loss: str = "ELBO+MMD",
+        number_of_components: int = 10,
+        overlap_loss: bool = False,
+        phenotype_predictor: float = 0.0,
+        predictor: float = 0.0,
+        prior: str = "standard_normal",
     ):
         super().__init__()
         self.input_shape = input_shape
@@ -89,6 +90,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         self.loss = loss
         self.number_of_components = number_of_components
         self.overlap_loss = overlap_loss
+        self.pheno_class = phenotype_predictor
         self.predictor = predictor
         self.prior = prior
 
@@ -161,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             loss=self.loss,
             number_of_components=k,
             overlap_loss=self.overlap_loss,
+            phenotype_prediction=self.pheno_class,
             predictor=self.predictor,
         ).build(self.input_shape)[3:]
 
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 5eb99b6e..f4f4f151 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -152,6 +152,7 @@ def tune_search(
             loss=loss,
             number_of_components=k,
             overlap_loss=overlap_loss,
+            phenotype_predictor=pheno_class,
             predictor=predictor,
         )
 
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index 74587dbf..0e192a95 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -65,7 +65,7 @@ def test_get_callbacks(
     assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
 
 
-@settings(max_examples=1, deadline=None)
+@settings(max_examples=5, deadline=None)
 @given(
     X_train=arrays(
         dtype=float,
@@ -93,7 +93,7 @@ def test_tune_search(
         )
     )[1:]
 
-    y_train = tf.random.uniform(shape=(X_train.shape[0],), maxval=1.0)
+    y_train = tf.random.uniform(shape=(X_train.shape[1],), maxval=1.0)
 
     deepof.train_utils.tune_search(
         data=[X_train, y_train, X_train, y_train],
-- 
GitLab