Commit 96d787f5 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 4ad26649
......@@ -397,8 +397,8 @@ else:
variational=variational,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
phenotype_class=pheno_class,
predictor=predictor,
phenotype_prediction=pheno_class,
next_sequence_prediction=predictor,
loss=loss,
logparam=logparam,
outpath=output_path,
......
......@@ -69,8 +69,9 @@ def get_callbacks(
X_train: np.array,
batch_size: int,
variational: bool,
phenotype_class: float,
predictor: float,
phenotype_prediction: float,
next_sequence_prediction: float,
rule_based_prediction: float,
loss: str,
X_val: np.array = None,
cp: bool = False,
......@@ -97,8 +98,9 @@ def get_callbacks(
run_ID = "{}{}{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"),
("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
("_Pheno={}".format(phenotype_class) if phenotype_class > 0 else ""),
("_NextSeqPred={}".format(next_sequence_prediction) if next_sequence_prediction > 0 and variational else ""),
("_PhenoPred={}".format(phenotype_prediction) if phenotype_prediction > 0 else ""),
("_RuleBasedPred={}".format(rule_based_prediction) if rule_based_prediction > 0 else ""),
("_loss={}".format(loss) if variational else ""),
("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
("_k={}".format(logparam["k"]) if logparam is not None else ""),
......@@ -295,8 +297,9 @@ def autoencoder_fitting(
batch_size=batch_size,
cp=save_checkpoints,
variational=variational,
phenotype_class=phenotype_prediction,
predictor=next_sequence_prediction,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
loss=loss,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
......
......@@ -42,24 +42,27 @@ def test_load_treatments():
),
batch_size=st.integers(min_value=128, max_value=512),
loss=st.one_of(st.just("test_A"), st.just("test_B")),
predictor=st.floats(min_value=0.0, max_value=1.0),
pheno_class=st.floats(min_value=0.0, max_value=1.0),
next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
variational=st.booleans(),
)
def test_get_callbacks(
X_train,
batch_size,
variational,
predictor,
pheno_class,
next_sequence_prediction,
phenotype_prediction,
rule_based_prediction,
loss,
):
callbacks = deepof.train_utils.get_callbacks(
X_train=X_train,
batch_size=batch_size,
variational=variational,
phenotype_class=pheno_class,
predictor=predictor,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
loss=loss,
X_val=X_train,
cp=True,
......@@ -174,8 +177,9 @@ def test_tune_search(
X_train=X_train,
batch_size=batch_size,
variational=(hypermodel == "S2SGMVAE"),
phenotype_class=0,
predictor=predictor,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
loss=loss,
X_val=X_train,
cp=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment