Commit 9e6df3ca authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 1e5ce83e
......@@ -102,8 +102,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
mmd_warmup_epochs: int = 0,
number_of_components: int = 10,
overlap_loss: float = False,
phenotype_predictor: float = 0.0,
predictor: float = 0.0,
next_sequence_prediction: float = 0.0,
phenotype_prediction: float = 0.0,
rule_based_prediction: float = 0.0,
prior: str = "standard_normal",
):
super().__init__()
......@@ -115,8 +116,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
self.mmd_warmup_epochs = mmd_warmup_epochs
self.number_of_components = number_of_components
self.overlap_loss = overlap_loss
self.pheno_class = phenotype_predictor
self.predictor = predictor
self.next_sequence_prediction = next_sequence_prediction
self.phenotype_prediction = phenotype_prediction
self.rule_based_prediction = rule_based_prediction
self.prior = prior
assert (
......@@ -186,8 +188,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
mmd_warmup_epochs=self.mmd_warmup_epochs,
number_of_components=k,
overlap_loss=self.overlap_loss,
phenotype_prediction=self.pheno_class,
next_sequence_prediction=self.predictor,
next_sequence_prediction=self.next_sequence_prediction,
phenotype_prediction=self.phenotype_prediction,
rule_based_prediction=self.rule_based_prediction,
).build(self.input_shape)[-3]
return gmvaep
......@@ -415,8 +415,8 @@ else:
loss=loss,
mmd_warmup_epochs=mmd_wu,
overlap_loss=overlap_loss,
phenotype_class=pheno_class,
predictor=predictor,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
project_name="{}-based_{}_{}".format(input_type, hyp, tune.capitalize()),
callbacks=[
tensorboard_callback,
......
......@@ -472,8 +472,9 @@ def tune_search(
loss: str,
mmd_warmup_epochs: int,
overlap_loss: float,
phenotype_class: float,
predictor: float,
next_sequence_prediction: float,
phenotype_prediction: float,
rule_based_prediction: float,
project_name: str,
callbacks: List,
n_epochs: int = 30,
......@@ -517,7 +518,7 @@ def tune_search(
if hypermodel == "S2SAE": # pragma: no cover
assert (
predictor == 0.0 and phenotype_class == 0.0
next_sequence_prediction == 0.0 and phenotype_prediction == 0.0
), "Prediction branches are only available for variational models. See documentation for more details"
batch_size = 1
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
......@@ -532,8 +533,9 @@ def tune_search(
mmd_warmup_epochs=mmd_warmup_epochs,
number_of_components=k,
overlap_loss=overlap_loss,
phenotype_predictor=phenotype_class,
predictor=predictor,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
)
else:
......@@ -574,11 +576,19 @@ def tune_search(
Xs, ys = [X_train], [X_train]
Xvals, yvals = [X_val], [X_val]
if predictor > 0.0:
if next_sequence_prediction > 0.0:
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if phenotype_class > 0.0:
if phenotype_prediction > 0.0:
ys += [y_train[:, 0]]
yvals += [y_val[:, 0]]
# Remove the used column (phenotype) from both y arrays
y_train = y_train[:, 1:]
y_val = y_val[:, 1:]
if rule_based_prediction > 0.0:
ys += [y_train]
yvals += [y_val]
......
......@@ -51,5 +51,5 @@ def test_SEQ_2_SEQ_GMVAE_hypermodel_build(
),
loss=loss,
number_of_components=number_of_components,
predictor=True,
next_sequence_prediction=True,
).build(hp=HyperParameters())
......@@ -193,8 +193,8 @@ def test_tune_search(
loss=loss,
mmd_warmup_epochs=0,
overlap_loss=overlap_loss,
phenotype_class=pheno_class,
predictor=predictor,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
project_name="test_run",
callbacks=callbacks,
n_epochs=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment