Commit 4ad26649 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 1c00ad9d
......@@ -193,7 +193,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
next_sequence_prediction=self.next_sequence_prediction,
phenotype_prediction=self.phenotype_prediction,
rule_based_prediction=self.rule_based_prediction,
rule_based_features=self.rule_based_features
rule_based_features=self.rule_based_features,
).build(self.input_shape)[-3]
return gmvaep
......@@ -518,7 +518,7 @@ def tune_search(
if hypermodel == "S2SAE": # pragma: no cover
assert (
next_sequence_prediction == 0.0 and phenotype_prediction == 0.0
next_sequence_prediction == 0.0 and phenotype_prediction == 0.0
), "Prediction branches are only available for variational models. See documentation for more details"
batch_size = 1
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
......
......@@ -82,20 +82,24 @@ def test_get_callbacks(
@settings(max_examples=10, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
loss=st.one_of(st.just("ELBO"), st.just("MMD"), st.just("ELBO+MMD")),
pheno_class=st.one_of(st.just(1.0), st.just(0.0)),
predictor=st.one_of(st.just(1.0), st.just(0.0)),
next_sequence_prediction=st.one_of(st.just(1.0), st.just(0.0)),
phenotype_prediction=st.one_of(st.just(1.0), st.just(0.0)),
rule_based_prediction=st.one_of(st.just(1.0), st.just(0.0)),
variational=st.one_of(st.just(True), st.just(False)),
)
def test_autoencoder_fitting(
loss,
pheno_class,
predictor,
next_sequence_prediction,
phenotype_prediction,
rule_based_prediction,
variational,
):
X_train = np.random.uniform(-1, 1, [20, 5, 6])
y_train = np.round(np.random.uniform(0, 1, 20))
y_train = np.round(np.random.uniform(0, 1, [20, 1]))
if rule_based_prediction:
y_train = np.concatenate([y_train, np.random.uniform(0, 1, [20, 6])], axis=1)
if predictor:
if next_sequence_prediction:
y_train = y_train[1:]
preprocessed_data = (X_train, y_train, X_train, y_train)
......@@ -118,8 +122,9 @@ def test_autoencoder_fitting(
mmd_warmup=1,
n_components=2,
loss=loss,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
variational=variational,
entropy_samples=10,
entropy_knn=5,
......@@ -147,8 +152,9 @@ def test_autoencoder_fitting(
k=st.integers(min_value=1, max_value=10),
loss=st.one_of(st.just("ELBO"), st.just("MMD")),
overlap_loss=st.floats(min_value=0.0, max_value=1.0),
pheno_class=st.floats(min_value=0.0, max_value=1.0),
predictor=st.floats(min_value=0.0, max_value=1.0),
next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
)
def test_tune_search(
X_train,
......@@ -159,8 +165,9 @@ def test_tune_search(
k,
loss,
overlap_loss,
pheno_class,
predictor,
next_sequence_prediction,
phenotype_prediction,
rule_based_prediction,
):
callbacks = list(
deepof.train_utils.get_callbacks(
......@@ -193,8 +200,9 @@ def test_tune_search(
loss=loss,
mmd_warmup_epochs=0,
overlap_loss=overlap_loss,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
project_name="test_run",
callbacks=callbacks,
n_epochs=1,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment