From 9e6df3ca4e9b94f96e3ee5a4bebdd7f407f20e27 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Tue, 13 Apr 2021 19:57:17 +0200 Subject: [PATCH] Added extra branch to main autoencoder for rule_based prediction --- deepof/hypermodels.py | 15 +++++++++------ deepof/train_model.py | 4 ++-- deepof/train_utils.py | 24 +++++++++++++++++------- tests/test_build_hypermodels.py | 2 +- tests/test_train_utils.py | 4 ++-- 5 files changed, 31 insertions(+), 18 deletions(-) diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py index 0c2d063b..9c9e983d 100644 --- a/deepof/hypermodels.py +++ b/deepof/hypermodels.py @@ -102,8 +102,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel): mmd_warmup_epochs: int = 0, number_of_components: int = 10, overlap_loss: float = False, - phenotype_predictor: float = 0.0, - predictor: float = 0.0, + next_sequence_prediction: float = 0.0, + phenotype_prediction: float = 0.0, + rule_based_prediction: float = 0.0, prior: str = "standard_normal", ): super().__init__() @@ -115,8 +116,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel): self.mmd_warmup_epochs = mmd_warmup_epochs self.number_of_components = number_of_components self.overlap_loss = overlap_loss - self.pheno_class = phenotype_predictor - self.predictor = predictor + self.next_sequence_prediction = next_sequence_prediction + self.phenotype_prediction = phenotype_prediction + self.rule_based_prediction = rule_based_prediction self.prior = prior assert ( @@ -186,8 +188,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel): mmd_warmup_epochs=self.mmd_warmup_epochs, number_of_components=k, overlap_loss=self.overlap_loss, - phenotype_prediction=self.pheno_class, - next_sequence_prediction=self.predictor, + next_sequence_prediction=self.next_sequence_prediction, + phenotype_prediction=self.phenotype_prediction, + rule_based_prediction=self.rule_based_prediction, ).build(self.input_shape)[-3] return gmvaep diff --git a/deepof/train_model.py b/deepof/train_model.py index 5b4436cf..b32449b4 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -415,8 +415,8 @@ else: loss=loss, mmd_warmup_epochs=mmd_wu, overlap_loss=overlap_loss, - phenotype_class=pheno_class, - predictor=predictor, + phenotype_prediction=pheno_class, + next_sequence_prediction=predictor, project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()), callbacks=[ tensorboard_callback, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index a8cbc1fc..a3c34910 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -472,8 +472,9 @@ def tune_search( loss: str, mmd_warmup_epochs: int, overlap_loss: float, - phenotype_class: float, - predictor: float, + next_sequence_prediction: float, + phenotype_prediction: float, + rule_based_prediction: float, project_name: str, callbacks: List, n_epochs: int = 30, @@ -517,7 +518,7 @@ def tune_search( if hypermodel == "S2SAE": # pragma: no cover assert ( - predictor == 0.0 and phenotype_class == 0.0 + next_sequence_prediction == 0.0 and phenotype_prediction == 0.0 ), "Prediction branches are only available for variational models. See documentation for more details" batch_size = 1 hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape) @@ -532,8 +533,9 @@ def tune_search( mmd_warmup_epochs=mmd_warmup_epochs, number_of_components=k, overlap_loss=overlap_loss, - phenotype_predictor=phenotype_class, - predictor=predictor, + next_sequence_prediction=next_sequence_prediction, + phenotype_prediction=phenotype_prediction, + rule_based_prediction=rule_based_prediction, ) else: @@ -574,11 +576,19 @@ def tune_search( Xs, ys = [X_train], [X_train] Xvals, yvals = [X_val], [X_val] - if predictor > 0.0: + if next_sequence_prediction > 0.0: Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]] Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]] - if phenotype_class > 0.0: + if phenotype_prediction > 0.0: + ys += [y_train[:, 0]] + yvals += [y_val[:, 0]] + + # Remove the used column (phenotype) from both y arrays + y_train = y_train[:, 1:] + y_val = y_val[:, 1:] + + if rule_based_prediction > 0.0: ys += [y_train] yvals += [y_val] diff --git a/tests/test_build_hypermodels.py b/tests/test_build_hypermodels.py index 47535580..d803789e 100644 --- a/tests/test_build_hypermodels.py +++ b/tests/test_build_hypermodels.py @@ -51,5 +51,5 @@ def test_SEQ_2_SEQ_GMVAE_hypermodel_build( ), loss=loss, number_of_components=number_of_components, - predictor=True, + next_sequence_prediction=True, ).build(hp=HyperParameters()) diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index 53b2d921..e8aa23cf 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -193,8 +193,8 @@ def test_tune_search( loss=loss, mmd_warmup_epochs=0, overlap_loss=overlap_loss, - phenotype_class=pheno_class, - predictor=predictor, + phenotype_prediction=pheno_class, + next_sequence_prediction=predictor, project_name="test_run", callbacks=callbacks, n_epochs=1, -- GitLab