From 9e6df3ca4e9b94f96e3ee5a4bebdd7f407f20e27 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 13 Apr 2021 19:57:17 +0200
Subject: [PATCH] Added extra branch to main autoencoder for rule_based
 prediction

---
 deepof/hypermodels.py           | 15 +++++++++------
 deepof/train_model.py           |  4 ++--
 deepof/train_utils.py           | 24 +++++++++++++++++-------
 tests/test_build_hypermodels.py |  2 +-
 tests/test_train_utils.py       |  4 ++--
 5 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index 0c2d063b..9c9e983d 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -102,8 +102,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         mmd_warmup_epochs: int = 0,
         number_of_components: int = 10,
         overlap_loss: float = False,
-        phenotype_predictor: float = 0.0,
-        predictor: float = 0.0,
+        next_sequence_prediction: float = 0.0,
+        phenotype_prediction: float = 0.0,
+        rule_based_prediction: float = 0.0,
         prior: str = "standard_normal",
     ):
         super().__init__()
@@ -115,8 +116,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         self.mmd_warmup_epochs = mmd_warmup_epochs
         self.number_of_components = number_of_components
         self.overlap_loss = overlap_loss
-        self.pheno_class = phenotype_predictor
-        self.predictor = predictor
+        self.next_sequence_prediction = next_sequence_prediction
+        self.phenotype_prediction = phenotype_prediction
+        self.rule_based_prediction = rule_based_prediction
         self.prior = prior
 
         assert (
@@ -186,8 +188,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             mmd_warmup_epochs=self.mmd_warmup_epochs,
             number_of_components=k,
             overlap_loss=self.overlap_loss,
-            phenotype_prediction=self.pheno_class,
-            next_sequence_prediction=self.predictor,
+            next_sequence_prediction=self.next_sequence_prediction,
+            phenotype_prediction=self.phenotype_prediction,
+            rule_based_prediction=self.rule_based_prediction,
         ).build(self.input_shape)[-3]
 
         return gmvaep
diff --git a/deepof/train_model.py b/deepof/train_model.py
index 5b4436cf..b32449b4 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -415,8 +415,8 @@ else:
         loss=loss,
         mmd_warmup_epochs=mmd_wu,
         overlap_loss=overlap_loss,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        phenotype_prediction=pheno_class,
+        next_sequence_prediction=predictor,
         project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
         callbacks=[
             tensorboard_callback,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index a8cbc1fc..a3c34910 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -472,8 +472,9 @@ def tune_search(
     loss: str,
     mmd_warmup_epochs: int,
     overlap_loss: float,
-    phenotype_class: float,
-    predictor: float,
+    next_sequence_prediction: float,
+    phenotype_prediction: float,
+    rule_based_prediction: float,
     project_name: str,
     callbacks: List,
     n_epochs: int = 30,
@@ -517,7 +518,7 @@ def tune_search(
 
     if hypermodel == "S2SAE":  # pragma: no cover
         assert (
-            predictor == 0.0 and phenotype_class == 0.0
+                next_sequence_prediction == 0.0 and phenotype_prediction == 0.0
         ), "Prediction branches are only available for variational models. See documentation for more details"
         batch_size = 1
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
@@ -532,8 +533,9 @@ def tune_search(
             mmd_warmup_epochs=mmd_warmup_epochs,
             number_of_components=k,
             overlap_loss=overlap_loss,
-            phenotype_predictor=phenotype_class,
-            predictor=predictor,
+            next_sequence_prediction=next_sequence_prediction,
+            phenotype_prediction=phenotype_prediction,
+            rule_based_prediction=rule_based_prediction,
         )
 
     else:
@@ -574,11 +576,19 @@ def tune_search(
     Xs, ys = [X_train], [X_train]
     Xvals, yvals = [X_val], [X_val]
 
-    if predictor > 0.0:
+    if next_sequence_prediction > 0.0:
         Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
         Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
 
-    if phenotype_class > 0.0:
+    if phenotype_prediction > 0.0:
+        ys += [y_train[:, 0]]
+        yvals += [y_val[:, 0]]
+
+        # Remove the used column (phenotype) from both y arrays
+        y_train = y_train[:, 1:]
+        y_val = y_val[:, 1:]
+
+    if rule_based_prediction > 0.0:
         ys += [y_train]
         yvals += [y_val]
 
diff --git a/tests/test_build_hypermodels.py b/tests/test_build_hypermodels.py
index 47535580..d803789e 100644
--- a/tests/test_build_hypermodels.py
+++ b/tests/test_build_hypermodels.py
@@ -51,5 +51,5 @@ def test_SEQ_2_SEQ_GMVAE_hypermodel_build(
         ),
         loss=loss,
         number_of_components=number_of_components,
-        predictor=True,
+        next_sequence_prediction=True,
     ).build(hp=HyperParameters())
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index 53b2d921..e8aa23cf 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -193,8 +193,8 @@ def test_tune_search(
         loss=loss,
         mmd_warmup_epochs=0,
         overlap_loss=overlap_loss,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        phenotype_prediction=pheno_class,
+        next_sequence_prediction=predictor,
         project_name="test_run",
         callbacks=callbacks,
         n_epochs=1,
-- 
GitLab