Commit 4ad26649 authored by lucas_miranda

Added extra branch to main autoencoder for rule_based prediction

parent 1c00ad9d
@@ -193,7 +193,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             next_sequence_prediction=self.next_sequence_prediction,
             phenotype_prediction=self.phenotype_prediction,
             rule_based_prediction=self.rule_based_prediction,
-            rule_based_features=self.rule_based_features
+            rule_based_features=self.rule_based_features,
         ).build(self.input_shape)[-3]

         return gmvaep
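
For orientation, the call above builds the full GMVAE inside the hypermodel and keeps only the main autoencoder output (index [-3] of the built model list). Below is a minimal sketch of the equivalent call from user code. The keyword names are copied from the hunk; the deepof.models import path, the rule_based_features value, and the input shape are assumptions for illustration only.

import deepof.models

# Shapes follow the tests below: 20 windows of length 5 with 6 features each (assumed).
input_shape = (20, 5, 6)

# Sketch: enable the new rule-based branch when building the model.
gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
    next_sequence_prediction=0.0,
    phenotype_prediction=0.0,
    rule_based_prediction=1.0,  # weight of the new rule-based prediction branch
    rule_based_features=6,      # assumed: one output per rule-based trait
).build(input_shape)[-3]        # keep only the main reconstruction model, as in the hypermodel
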
@@ -518,7 +518,7 @@ def tune_search(
     if hypermodel == "S2SAE":  # pragma: no cover
         assert (
             next_sequence_prediction == 0.0 and phenotype_prediction == 0.0
         ), "Prediction branches are only available for variational models. See documentation for more details"

         batch_size = 1
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
@@ -82,20 +82,24 @@ def test_get_callbacks(
 @settings(max_examples=10, deadline=None, suppress_health_check=[HealthCheck.too_slow])
 @given(
     loss=st.one_of(st.just("ELBO"), st.just("MMD"), st.just("ELBO+MMD")),
-    pheno_class=st.one_of(st.just(1.0), st.just(0.0)),
-    predictor=st.one_of(st.just(1.0), st.just(0.0)),
+    next_sequence_prediction=st.one_of(st.just(1.0), st.just(0.0)),
+    phenotype_prediction=st.one_of(st.just(1.0), st.just(0.0)),
+    rule_based_prediction=st.one_of(st.just(1.0), st.just(0.0)),
     variational=st.one_of(st.just(True), st.just(False)),
 )
 def test_autoencoder_fitting(
     loss,
-    pheno_class,
-    predictor,
+    next_sequence_prediction,
+    phenotype_prediction,
+    rule_based_prediction,
     variational,
 ):
     X_train = np.random.uniform(-1, 1, [20, 5, 6])
-    y_train = np.round(np.random.uniform(0, 1, 20))
+    y_train = np.round(np.random.uniform(0, 1, [20, 1]))
+    if rule_based_prediction:
+        y_train = np.concatenate([y_train, np.random.uniform(0, 1, [20, 6])], axis=1)

-    if predictor:
+    if next_sequence_prediction:
         y_train = y_train[1:]

     preprocessed_data = (X_train, y_train, X_train, y_train)
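
As a quick sanity check of the label layout the updated test builds, the snippet below reproduces it with plain NumPy: one binary phenotype column extended with six rule-based columns. How downstream code splits these columns is an assumption here, not taken from this commit.

import numpy as np

y_train = np.round(np.random.uniform(0, 1, [20, 1]))  # binary phenotype labels, shape (20, 1)
rule_based = np.random.uniform(0, 1, [20, 6])         # six rule-based traits per sample
y_train = np.concatenate([y_train, rule_based], axis=1)

print(y_train.shape)  # (20, 7): phenotype in column 0, rule-based targets in columns 1-6
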
@@ -118,8 +122,9 @@ def test_autoencoder_fitting(
         mmd_warmup=1,
         n_components=2,
         loss=loss,
-        phenotype_prediction=pheno_class,
-        next_sequence_prediction=predictor,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         variational=variational,
         entropy_samples=10,
         entropy_knn=5,
@@ -147,8 +152,9 @@ def test_autoencoder_fitting(
     k=st.integers(min_value=1, max_value=10),
     loss=st.one_of(st.just("ELBO"), st.just("MMD")),
     overlap_loss=st.floats(min_value=0.0, max_value=1.0),
-    pheno_class=st.floats(min_value=0.0, max_value=1.0),
-    predictor=st.floats(min_value=0.0, max_value=1.0),
+    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
+    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
+    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
 )
 def test_tune_search(
     X_train,
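
The strategies above draw each prediction-branch weight independently from [0, 1]. For readers unfamiliar with Hypothesis, here is a minimal, self-contained example using the same strategies; the test name and assertion are placeholders, not part of the repository.

from hypothesis import HealthCheck, given, settings, strategies as st

@settings(max_examples=5, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
)
def test_branch_weights_in_range(
    next_sequence_prediction, phenotype_prediction, rule_based_prediction
):
    # Placeholder body: the real tests forward these weights to deepof.train_utils.tune_search.
    for weight in (next_sequence_prediction, phenotype_prediction, rule_based_prediction):
        assert 0.0 <= weight <= 1.0
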
@@ -159,8 +165,9 @@ def test_tune_search(
     k,
     loss,
     overlap_loss,
-    pheno_class,
-    predictor,
+    next_sequence_prediction,
+    phenotype_prediction,
+    rule_based_prediction,
 ):
     callbacks = list(
         deepof.train_utils.get_callbacks(
@@ -193,8 +200,9 @@ def test_tune_search(
         loss=loss,
         mmd_warmup_epochs=0,
         overlap_loss=overlap_loss,
-        phenotype_prediction=pheno_class,
-        next_sequence_prediction=predictor,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         project_name="test_run",
         callbacks=callbacks,
         n_epochs=1,