diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py index 0c2d063be3bb0c76c348bcbd6c872e428feec9c0..9c9e983d7f2869061808883a2249866e55e8343d 100644 --- a/deepof/hypermodels.py +++ b/deepof/hypermodels.py @@ -102,8 +102,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel): mmd_warmup_epochs: int = 0, number_of_components: int = 10, overlap_loss: float = False, - phenotype_predictor: float = 0.0, - predictor: float = 0.0, + next_sequence_prediction: float = 0.0, + phenotype_prediction: float = 0.0, + rule_based_prediction: float = 0.0, prior: str = "standard_normal", ): super().__init__() @@ -115,8 +116,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel): self.mmd_warmup_epochs = mmd_warmup_epochs self.number_of_components = number_of_components self.overlap_loss = overlap_loss - self.pheno_class = phenotype_predictor - self.predictor = predictor + self.next_sequence_prediction = next_sequence_prediction + self.phenotype_prediction = phenotype_prediction + self.rule_based_prediction = rule_based_prediction self.prior = prior assert ( @@ -186,8 +188,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel): mmd_warmup_epochs=self.mmd_warmup_epochs, number_of_components=k, overlap_loss=self.overlap_loss, - phenotype_prediction=self.pheno_class, - next_sequence_prediction=self.predictor, + next_sequence_prediction=self.next_sequence_prediction, + phenotype_prediction=self.phenotype_prediction, + rule_based_prediction=self.rule_based_prediction, ).build(self.input_shape)[-3] return gmvaep diff --git a/deepof/train_model.py b/deepof/train_model.py index 5b4436cfbc393ed8070a4049bb8978fbdd2addb3..b32449b464450f5848579ef944c77c38344a5c4f 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -415,8 +415,8 @@ else: loss=loss, mmd_warmup_epochs=mmd_wu, overlap_loss=overlap_loss, - phenotype_class=pheno_class, - predictor=predictor, + phenotype_prediction=pheno_class, + next_sequence_prediction=predictor, project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()), callbacks=[ tensorboard_callback, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index a8cbc1fc3979fef177e989047f562c8f399b21dc..a3c3491001fbc525beaac621fc43fe4e38df19d6 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -472,8 +472,9 @@ def tune_search( loss: str, mmd_warmup_epochs: int, overlap_loss: float, - phenotype_class: float, - predictor: float, + next_sequence_prediction: float, + phenotype_prediction: float, + rule_based_prediction: float, project_name: str, callbacks: List, n_epochs: int = 30, @@ -517,7 +518,7 @@ def tune_search( if hypermodel == "S2SAE": # pragma: no cover assert ( - predictor == 0.0 and phenotype_class == 0.0 + next_sequence_prediction == 0.0 and phenotype_prediction == 0.0 ), "Prediction branches are only available for variational models. See documentation for more details" batch_size = 1 hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape) @@ -532,8 +533,9 @@ def tune_search( mmd_warmup_epochs=mmd_warmup_epochs, number_of_components=k, overlap_loss=overlap_loss, - phenotype_predictor=phenotype_class, - predictor=predictor, + next_sequence_prediction=next_sequence_prediction, + phenotype_prediction=phenotype_prediction, + rule_based_prediction=rule_based_prediction, ) else: @@ -574,11 +576,19 @@ def tune_search( Xs, ys = [X_train], [X_train] Xvals, yvals = [X_val], [X_val] - if predictor > 0.0: + if next_sequence_prediction > 0.0: Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]] Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]] - if phenotype_class > 0.0: + if phenotype_prediction > 0.0: + ys += [y_train[:, 0]] + yvals += [y_val[:, 0]] + + # Remove the used column (phenotype) from both y arrays + y_train = y_train[:, 1:] + y_val = y_val[:, 1:] + + if rule_based_prediction > 0.0: ys += [y_train] yvals += [y_val] diff --git a/tests/test_build_hypermodels.py b/tests/test_build_hypermodels.py index 4753558024815cb237ac26815511eabab132866e..d803789e8cf9d0bbaaae586b386f7b9cb7be361c 100644 --- a/tests/test_build_hypermodels.py +++ b/tests/test_build_hypermodels.py @@ -51,5 +51,5 @@ def test_SEQ_2_SEQ_GMVAE_hypermodel_build( ), loss=loss, number_of_components=number_of_components, - predictor=True, + next_sequence_prediction=True, ).build(hp=HyperParameters()) diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index 53b2d921a1a04642f57db9b24cc125f3fe69d5c8..e8aa23cfc8cf2eea52757c4638f89405d8a33ddd 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -193,8 +193,8 @@ def test_tune_search( loss=loss, mmd_warmup_epochs=0, overlap_loss=overlap_loss, - phenotype_class=pheno_class, - predictor=predictor, + phenotype_prediction=pheno_class, + next_sequence_prediction=predictor, project_name="test_run", callbacks=callbacks, n_epochs=1,