From db90399afdab351d7a54aafbba74bcb3ee153e81 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Tue, 17 Nov 2020 11:27:29 +0100 Subject: [PATCH] Updated train_model.py to be compatible with phenotype classification --- deepof/hypermodels.py | 21 ++++++++++++--------- deepof/train_utils.py | 1 + tests/test_train_utils.py | 4 ++-- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py index b852bf1b..d20b9184 100644 --- a/deepof/hypermodels.py +++ b/deepof/hypermodels.py @@ -71,15 +71,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel): def __init__( self, - input_shape, - entropy_reg_weight=0.0, - huber_delta=1.0, - learn_rate=1e-3, - loss="ELBO+MMD", - number_of_components=10, - overlap_loss=False, - predictor=0.0, - prior="standard_normal", + input_shape: tuple, + entropy_reg_weight: float = 0.0, + huber_delta: float = 1.0, + learn_rate: float = 1e-3, + loss: str = "ELBO+MMD", + number_of_components: int = 10, + overlap_loss: bool = False, + phenotype_predictor: float = 0.0, + predictor: float = 0.0, + prior: str = "standard_normal", ): super().__init__() self.input_shape = input_shape @@ -89,6 +90,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): self.loss = loss self.number_of_components = number_of_components self.overlap_loss = overlap_loss + self.pheno_class = phenotype_predictor self.predictor = predictor self.prior = prior @@ -161,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): loss=self.loss, number_of_components=k, overlap_loss=self.overlap_loss, + phenotype_prediction=self.pheno_class, predictor=self.predictor, ).build(self.input_shape)[3:] diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 5eb99b6e..f4f4f151 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -152,6 +152,7 @@ def tune_search( loss=loss, number_of_components=k, overlap_loss=overlap_loss, + phenotype_predictor=pheno_class, predictor=predictor, ) diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index 74587dbf..0e192a95 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -65,7 +65,7 @@ def test_get_callbacks( assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler -@settings(max_examples=1, deadline=None) +@settings(max_examples=5, deadline=None) @given( X_train=arrays( dtype=float, @@ -93,7 +93,7 @@ def test_tune_search( ) )[1:] - y_train = tf.random.uniform(shape=(X_train.shape[0],), maxval=1.0) + y_train = tf.random.uniform(shape=(X_train.shape[1],), maxval=1.0) deepof.train_utils.tune_search( data=[X_train, y_train, X_train, y_train], -- GitLab