Commit 96d787f5 authored by Lucas Miranda

Added extra branch to main autoencoder for rule_based prediction

parent 4ad26649
@@ -397,8 +397,8 @@ else:
         variational=variational,
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        phenotype_prediction=pheno_class,
+        next_sequence_prediction=predictor,
         loss=loss,
         logparam=logparam,
         outpath=output_path,
@@ -69,8 +69,9 @@ def get_callbacks(
     X_train: np.array,
     batch_size: int,
     variational: bool,
-    phenotype_class: float,
-    predictor: float,
+    phenotype_prediction: float,
+    next_sequence_prediction: float,
+    rule_based_prediction: float,
     loss: str,
     X_val: np.array = None,
     cp: bool = False,
@@ -97,8 +98,9 @@ def get_callbacks(
     run_ID = "{}{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
-        ("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
-        ("_Pheno={}".format(phenotype_class) if phenotype_class > 0 else ""),
+        ("_NextSeqPred={}".format(next_sequence_prediction) if next_sequence_prediction > 0 and variational else ""),
+        ("_PhenoPred={}".format(phenotype_prediction) if phenotype_prediction > 0 else ""),
+        ("_RuleBasedPred={}".format(rule_based_prediction) if rule_based_prediction > 0 else ""),
         ("_loss={}".format(loss) if variational else ""),
         ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
         ("_k={}".format(logparam["k"]) if logparam is not None else ""),
@@ -295,8 +297,9 @@ def autoencoder_fitting(
         batch_size=batch_size,
         cp=save_checkpoints,
         variational=variational,
-        phenotype_class=phenotype_prediction,
-        predictor=next_sequence_prediction,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
@@ -42,24 +42,27 @@ def test_load_treatments():
     ),
     batch_size=st.integers(min_value=128, max_value=512),
     loss=st.one_of(st.just("test_A"), st.just("test_B")),
-    predictor=st.floats(min_value=0.0, max_value=1.0),
-    pheno_class=st.floats(min_value=0.0, max_value=1.0),
+    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
+    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
+    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
     variational=st.booleans(),
 )
 def test_get_callbacks(
     X_train,
     batch_size,
     variational,
-    predictor,
-    pheno_class,
+    next_sequence_prediction,
+    phenotype_prediction,
+    rule_based_prediction,
     loss,
 ):
     callbacks = deepof.train_utils.get_callbacks(
         X_train=X_train,
         batch_size=batch_size,
         variational=variational,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
         cp=True,
@@ -174,8 +177,9 @@ def test_tune_search(
         X_train=X_train,
         batch_size=batch_size,
         variational=(hypermodel == "S2SGMVAE"),
-        phenotype_class=0,
-        predictor=predictor,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
         cp=False,
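The test changes follow the existing hypothesis pattern: each prediction weight is drawn independently from [0, 1]. Below is a self-contained sketch of that pattern; the tag() helper and the test name are hypothetical, only the strategy layout mirrors the test in this commit.

# Self-contained sketch of the hypothesis pattern used above: each prediction
# weight is sampled independently from [0, 1]. The tag() helper and the test
# name are hypothetical; only the strategy layout mirrors this commit.
from hypothesis import given, settings
from hypothesis import strategies as st


def tag(name, weight):
    # Hypothetical helper mirroring the "_<Name>=<weight>" run_ID fragments.
    return "_{}={}".format(name, weight) if weight > 0 else ""


@settings(deadline=None)
@given(
    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
)
def test_rule_based_tag_is_optional(
    next_sequence_prediction, phenotype_prediction, rule_based_prediction
):
    run_id = "AE{}{}{}".format(
        tag("NextSeqPred", next_sequence_prediction),
        tag("PhenoPred", phenotype_prediction),
        tag("RuleBasedPred", rule_based_prediction),
    )
    # The rule-based fragment appears exactly when its weight is positive.
    assert ("_RuleBasedPred=" in run_id) == (rule_based_prediction > 0)


if __name__ == "__main__":
    test_rule_based_tag_is_optional()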