diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py index b852bf1b05d657208216f7cfe1b7491f45d56576..d20b91848c4e8b3d245c2f9b66fc8945d17cfc28 100644 --- a/deepof/hypermodels.py +++ b/deepof/hypermodels.py @@ -71,15 +71,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel): def __init__( self, - input_shape, - entropy_reg_weight=0.0, - huber_delta=1.0, - learn_rate=1e-3, - loss="ELBO+MMD", - number_of_components=10, - overlap_loss=False, - predictor=0.0, - prior="standard_normal", + input_shape: tuple, + entropy_reg_weight: float = 0.0, + huber_delta: float = 1.0, + learn_rate: float = 1e-3, + loss: str = "ELBO+MMD", + number_of_components: int = 10, + overlap_loss: bool = False, + phenotype_predictor: float = 0.0, + predictor: float = 0.0, + prior: str = "standard_normal", ): super().__init__() self.input_shape = input_shape @@ -89,6 +90,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): self.loss = loss self.number_of_components = number_of_components self.overlap_loss = overlap_loss + self.pheno_class = phenotype_predictor self.predictor = predictor self.prior = prior @@ -161,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): loss=self.loss, number_of_components=k, overlap_loss=self.overlap_loss, + phenotype_prediction=self.pheno_class, predictor=self.predictor, ).build(self.input_shape)[3:] diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 5eb99b6eb942a621a150cbb7241e286c6df73e7c..f4f4f15119efe2a44e18673a9e1af16d40787838 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -152,6 +152,7 @@ def tune_search( loss=loss, number_of_components=k, overlap_loss=overlap_loss, + phenotype_predictor=pheno_class, predictor=predictor, ) diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index 74587dbfd0f8aec86572241c1b84d6040a1ed35f..0e192a9575053dea2ed293a484189d87c2f08af9 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -65,7 +65,7 @@ def test_get_callbacks( assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler -@settings(max_examples=1, deadline=None) +@settings(max_examples=5, deadline=None) @given( X_train=arrays( dtype=float, @@ -93,7 +93,7 @@ def test_tune_search( ) )[1:] - y_train = tf.random.uniform(shape=(X_train.shape[0],), maxval=1.0) + y_train = tf.random.uniform(shape=(X_train.shape[1],), maxval=1.0) deepof.train_utils.tune_search( data=[X_train, y_train, X_train, y_train],