Commit 1c00ad9d authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 9e6df3ca
......@@ -105,6 +105,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
next_sequence_prediction: float = 0.0,
phenotype_prediction: float = 0.0,
rule_based_prediction: float = 0.0,
rule_based_features: int = 6,
prior: str = "standard_normal",
):
super().__init__()
......@@ -119,6 +120,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
self.next_sequence_prediction = next_sequence_prediction
self.phenotype_prediction = phenotype_prediction
self.rule_based_prediction = rule_based_prediction
self.rule_based_features = rule_based_features
self.prior = prior
assert (
......@@ -191,6 +193,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
next_sequence_prediction=self.next_sequence_prediction,
phenotype_prediction=self.phenotype_prediction,
rule_based_prediction=self.rule_based_prediction,
rule_based_features=self.rule_based_features
).build(self.input_shape)[-3]
return gmvaep
......@@ -536,6 +536,9 @@ def tune_search(
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
rule_based_features=(
y_train.shape[1] if not phenotype_prediction else y_train.shape[1] - 1
),
)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment