Commit db90399a authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated train_model.py to be compatible with phenotype classification

parent 7c7e6d62
Pipeline #86829 passed with stage
in 19 minutes and 35 seconds
...@@ -71,15 +71,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -71,15 +71,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
def __init__( def __init__(
self, self,
input_shape, input_shape: tuple,
entropy_reg_weight=0.0, entropy_reg_weight: float = 0.0,
huber_delta=1.0, huber_delta: float = 1.0,
learn_rate=1e-3, learn_rate: float = 1e-3,
loss="ELBO+MMD", loss: str = "ELBO+MMD",
number_of_components=10, number_of_components: int = 10,
overlap_loss=False, overlap_loss: bool = False,
predictor=0.0, phenotype_predictor: float = 0.0,
prior="standard_normal", predictor: float = 0.0,
prior: str = "standard_normal",
): ):
super().__init__() super().__init__()
self.input_shape = input_shape self.input_shape = input_shape
...@@ -89,6 +90,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -89,6 +90,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
self.loss = loss self.loss = loss
self.number_of_components = number_of_components self.number_of_components = number_of_components
self.overlap_loss = overlap_loss self.overlap_loss = overlap_loss
self.pheno_class = phenotype_predictor
self.predictor = predictor self.predictor = predictor
self.prior = prior self.prior = prior
...@@ -161,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -161,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
loss=self.loss, loss=self.loss,
number_of_components=k, number_of_components=k,
overlap_loss=self.overlap_loss, overlap_loss=self.overlap_loss,
phenotype_prediction=self.pheno_class,
predictor=self.predictor, predictor=self.predictor,
).build(self.input_shape)[3:] ).build(self.input_shape)[3:]
......
...@@ -152,6 +152,7 @@ def tune_search( ...@@ -152,6 +152,7 @@ def tune_search(
loss=loss, loss=loss,
number_of_components=k, number_of_components=k,
overlap_loss=overlap_loss, overlap_loss=overlap_loss,
phenotype_predictor=pheno_class,
predictor=predictor, predictor=predictor,
) )
......
...@@ -65,7 +65,7 @@ def test_get_callbacks( ...@@ -65,7 +65,7 @@ def test_get_callbacks(
assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
@settings(max_examples=1, deadline=None) @settings(max_examples=5, deadline=None)
@given( @given(
X_train=arrays( X_train=arrays(
dtype=float, dtype=float,
...@@ -93,7 +93,7 @@ def test_tune_search( ...@@ -93,7 +93,7 @@ def test_tune_search(
) )
)[1:] )[1:]
y_train = tf.random.uniform(shape=(X_train.shape[0],), maxval=1.0) y_train = tf.random.uniform(shape=(X_train.shape[1],), maxval=1.0)
deepof.train_utils.tune_search( deepof.train_utils.tune_search(
data=[X_train, y_train, X_train, y_train], data=[X_train, y_train, X_train, y_train],
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment