Commit 3e4a5e3d authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed bug in hyperparameter tuning

parent 3817e5c1
Pipeline #86893 failed with stage
in 16 minutes and 2 seconds
......@@ -74,8 +74,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
input_shape: tuple,
entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0,
kl_warmup_epochs: int = 0,
learn_rate: float = 1e-3,
loss: str = "ELBO+MMD",
mmd_warmup_epochs: int = 0,
number_of_components: int = 10,
overlap_loss: float = False,
phenotype_predictor: float = 0.0,
......@@ -86,8 +88,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
self.input_shape = input_shape
self.entropy_reg_weight = entropy_reg_weight
self.huber_delta = huber_delta
self.kl_warmup_epochs = kl_warmup_epochs
self.learn_rate = learn_rate
self.loss = loss
self.mmd_warmup_epochs = mmd_warmup_epochs
self.number_of_components = number_of_components
self.overlap_loss = overlap_loss
self.pheno_class = phenotype_predictor
......@@ -160,7 +164,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
},
entropy_reg_weight=self.entropy_reg_weight,
huber_delta=self.huber_delta,
kl_warmup_epochs=self.kl_warmup_epochs,
loss=self.loss,
mmd_warmup_epochs=self.mmd_warmup_epochs,
number_of_components=k,
overlap_loss=self.overlap_loss,
phenotype_prediction=self.pheno_class,
......
......@@ -415,13 +415,15 @@ else:
bayopt_trials=bayopt_trials,
hypermodel=hyp,
k=k,
kl_warmup_epochs=kl_wu,
loss=loss,
mmd_warmup_epochs=mmd_wu,
overlap_loss=overlap_loss,
pheno_class=pheno_class,
predictor=predictor,
project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
callbacks=[tensorboard_callback, cp_callback, onecycle],
n_replicas=3,
n_replicas=1,
n_epochs=30,
)
......
......@@ -101,13 +101,15 @@ def tune_search(
bayopt_trials: int,
hypermodel: str,
k: int,
kl_warmup_epochs: int,
loss: str,
mmd_warmup_epochs: int,
overlap_loss: float,
pheno_class: float,
predictor: float,
project_name: str,
callbacks: List,
n_epochs: int = 40,
n_epochs: int = 30,
n_replicas: int = 1,
) -> Union[bool, Tuple[Any, Any]]:
"""Define the search space using keras-tuner and bayesian optimization
......@@ -149,7 +151,9 @@ def tune_search(
elif hypermodel == "S2SGMVAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
input_shape=X_train.shape,
kl_warmup_epochs=kl_warmup_epochs,
loss=loss,
mmd_warmup_epochs=mmd_warmup_epochs,
number_of_components=k,
overlap_loss=overlap_loss,
phenotype_predictor=pheno_class,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment