From 96d787f5a8cf13d7cf650a3fe88bdfc16f548b51 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 13 Apr 2021 20:40:21 +0200
Subject: [PATCH] Added extra branch to main autoencoder for rule_based
 prediction

---
 deepof/train_model.py     |  4 ++--
 deepof/train_utils.py     | 15 +++++++++------
 tests/test_train_utils.py | 20 ++++++++++++--------
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/deepof/train_model.py b/deepof/train_model.py
index b32449b4..e7fe0010 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -397,8 +397,8 @@ else:
         variational=variational,
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        phenotype_prediction=pheno_class,
+        next_sequence_prediction=predictor,
         loss=loss,
         logparam=logparam,
         outpath=output_path,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 951e0cf6..1230498a 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -69,8 +69,9 @@ def get_callbacks(
     X_train: np.array,
     batch_size: int,
     variational: bool,
-    phenotype_class: float,
-    predictor: float,
+    phenotype_prediction: float,
+    next_sequence_prediction: float,
+    rule_based_prediction: float,
     loss: str,
     X_val: np.array = None,
     cp: bool = False,
@@ -97,8 +98,9 @@ def get_callbacks(
 
-    run_ID = "{}{}{}{}{}{}{}_{}".format(
+    run_ID = "{}{}{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
-        ("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
-        ("_Pheno={}".format(phenotype_class) if phenotype_class > 0 else ""),
+        ("_NextSeqPred={}".format(next_sequence_prediction) if next_sequence_prediction > 0 and variational else ""),
+        ("_PhenoPred={}".format(phenotype_prediction) if phenotype_prediction > 0 else ""),
+        ("_RuleBasedPred={}".format(rule_based_prediction) if rule_based_prediction > 0 else ""),
         ("_loss={}".format(loss) if variational else ""),
         ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
         ("_k={}".format(logparam["k"]) if logparam is not None else ""),
@@ -295,8 +297,9 @@ def autoencoder_fitting(
         batch_size=batch_size,
         cp=save_checkpoints,
         variational=variational,
-        phenotype_class=phenotype_prediction,
-        predictor=next_sequence_prediction,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index b30d3f67..47c34232 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -42,24 +42,27 @@ def test_load_treatments():
     ),
     batch_size=st.integers(min_value=128, max_value=512),
     loss=st.one_of(st.just("test_A"), st.just("test_B")),
-    predictor=st.floats(min_value=0.0, max_value=1.0),
-    pheno_class=st.floats(min_value=0.0, max_value=1.0),
+    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
+    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
+    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
     variational=st.booleans(),
 )
 def test_get_callbacks(
     X_train,
     batch_size,
     variational,
-    predictor,
-    pheno_class,
+    next_sequence_prediction,
+    phenotype_prediction,
+    rule_based_prediction,
     loss,
 ):
     callbacks = deepof.train_utils.get_callbacks(
         X_train=X_train,
         batch_size=batch_size,
         variational=variational,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
         cp=True,
@@ -174,8 +177,9 @@ def test_tune_search(
             X_train=X_train,
             batch_size=batch_size,
             variational=(hypermodel == "S2SGMVAE"),
-            phenotype_class=0,
-            predictor=predictor,
+            next_sequence_prediction=predictor,
+            phenotype_prediction=0,
+            rule_based_prediction=0,
             loss=loss,
             X_val=X_train,
             cp=False,
-- 
GitLab