From 96d787f5a8cf13d7cf650a3fe88bdfc16f548b51 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Tue, 13 Apr 2021 20:40:21 +0200 Subject: [PATCH] Added extra branch to main autoencoder for rule_based prediction --- deepof/train_model.py | 4 ++-- deepof/train_utils.py | 15 +++++++++------ tests/test_train_utils.py | 20 ++++++++++++-------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/deepof/train_model.py b/deepof/train_model.py index b32449b4..e7fe0010 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -397,8 +397,8 @@ else: variational=variational, entropy_samples=entropy_samples, entropy_knn=entropy_knn, - phenotype_class=pheno_class, - predictor=predictor, + phenotype_prediction=pheno_class, + next_sequence_prediction=predictor, loss=loss, logparam=logparam, outpath=output_path, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 951e0cf6..1230498a 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -69,8 +69,9 @@ def get_callbacks( X_train: np.array, batch_size: int, variational: bool, - phenotype_class: float, - predictor: float, + phenotype_prediction: float, + next_sequence_prediction: float, + rule_based_prediction: float, loss: str, X_val: np.array = None, cp: bool = False, @@ -97,8 +98,9 @@ def get_callbacks( run_ID = "{}{}{}{}{}{}{}_{}".format( ("GMVAE" if variational else "AE"), - ("_Pred={}".format(predictor) if predictor > 0 and variational else ""), - ("_Pheno={}".format(phenotype_class) if phenotype_class > 0 else ""), + ("_NextSeqPred={}".format(next_sequence_prediction) if next_sequence_prediction > 0 and variational else ""), + ("_PhenoPred={}".format(phenotype_prediction) if phenotype_prediction > 0 else ""), + ("_RuleBasedPred={}".format(rule_based_prediction) if rule_based_prediction > 0 else ""), ("_loss={}".format(loss) if variational else ""), ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""), ("_k={}".format(logparam["k"]) if logparam is not None else ""), @@ -295,8 +297,9 @@ def autoencoder_fitting( batch_size=batch_size, cp=save_checkpoints, variational=variational, - phenotype_class=phenotype_prediction, - predictor=next_sequence_prediction, + next_sequence_prediction=next_sequence_prediction, + phenotype_prediction=phenotype_prediction, + rule_based_prediction=rule_based_prediction, loss=loss, entropy_samples=entropy_samples, entropy_knn=entropy_knn, diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index b30d3f67..47c34232 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -42,24 +42,27 @@ def test_load_treatments(): ), batch_size=st.integers(min_value=128, max_value=512), loss=st.one_of(st.just("test_A"), st.just("test_B")), - predictor=st.floats(min_value=0.0, max_value=1.0), - pheno_class=st.floats(min_value=0.0, max_value=1.0), + next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0), + phenotype_prediction=st.floats(min_value=0.0, max_value=1.0), + rule_based_prediction=st.floats(min_value=0.0, max_value=1.0), variational=st.booleans(), ) def test_get_callbacks( X_train, batch_size, variational, - predictor, - pheno_class, + next_sequence_prediction, + phenotype_prediction, + rule_based_prediction, loss, ): callbacks = deepof.train_utils.get_callbacks( X_train=X_train, batch_size=batch_size, variational=variational, - phenotype_class=pheno_class, - predictor=predictor, + next_sequence_prediction=next_sequence_prediction, + phenotype_prediction=phenotype_prediction, + rule_based_prediction=rule_based_prediction, loss=loss, X_val=X_train, cp=True, @@ -174,8 +177,9 @@ def test_tune_search( X_train=X_train, batch_size=batch_size, variational=(hypermodel == "S2SGMVAE"), - phenotype_class=0, - predictor=predictor, + next_sequence_prediction=next_sequence_prediction, + phenotype_prediction=phenotype_prediction, + rule_based_prediction=rule_based_prediction, loss=loss, X_val=X_train, cp=False, -- GitLab