diff --git a/deepof/train_model.py b/deepof/train_model.py
index b32449b464450f5848579ef944c77c38344a5c4f..e7fe00105c607548593118b0f41dcf425ce6e218 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -397,8 +397,8 @@ else:
         variational=variational,
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        phenotype_prediction=pheno_class,
+        next_sequence_prediction=predictor,
         loss=loss,
         logparam=logparam,
         outpath=output_path,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 951e0cf6d90f8965b05c97c2f4bbb9b885570080..1230498a660299c17677b04b3f7e42b1a5f744e5 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -69,8 +69,9 @@ def get_callbacks(
     X_train: np.array,
     batch_size: int,
     variational: bool,
-    phenotype_class: float,
-    predictor: float,
+    phenotype_prediction: float,
+    next_sequence_prediction: float,
+    rule_based_prediction: float,
     loss: str,
     X_val: np.array = None,
     cp: bool = False,
@@ -97,8 +98,9 @@
 
     run_ID = "{}{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
-        ("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
-        ("_Pheno={}".format(phenotype_class) if phenotype_class > 0 else ""),
+        ("_NextSeqPred={}".format(next_sequence_prediction) if next_sequence_prediction > 0 and variational else ""),
+        ("_PhenoPred={}".format(phenotype_prediction) if phenotype_prediction > 0 else ""),
+        ("_RuleBasedPred={}".format(rule_based_prediction) if rule_based_prediction > 0 else ""),
         ("_loss={}".format(loss) if variational else ""),
         ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
         ("_k={}".format(logparam["k"]) if logparam is not None else ""),
@@ -295,8 +297,9 @@ def autoencoder_fitting(
         batch_size=batch_size,
         cp=save_checkpoints,
         variational=variational,
-        phenotype_class=phenotype_prediction,
-        predictor=next_sequence_prediction,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         entropy_samples=entropy_samples,
         entropy_knn=entropy_knn,
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index b30d3f67c8a75a3cdc7e45392516c3668ff79866..47c3423294ff704c11cf4d8a11bbf28fb05854b0 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -42,24 +42,27 @@ def test_load_treatments():
     ),
     batch_size=st.integers(min_value=128, max_value=512),
     loss=st.one_of(st.just("test_A"), st.just("test_B")),
-    predictor=st.floats(min_value=0.0, max_value=1.0),
-    pheno_class=st.floats(min_value=0.0, max_value=1.0),
+    next_sequence_prediction=st.floats(min_value=0.0, max_value=1.0),
+    phenotype_prediction=st.floats(min_value=0.0, max_value=1.0),
+    rule_based_prediction=st.floats(min_value=0.0, max_value=1.0),
     variational=st.booleans(),
 )
 def test_get_callbacks(
     X_train,
     batch_size,
     variational,
-    predictor,
-    pheno_class,
+    next_sequence_prediction,
+    phenotype_prediction,
+    rule_based_prediction,
     loss,
 ):
     callbacks = deepof.train_utils.get_callbacks(
         X_train=X_train,
         batch_size=batch_size,
         variational=variational,
-        phenotype_class=pheno_class,
-        predictor=predictor,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
         cp=True,
@@ -174,8 +177,9 @@ def test_tune_search(
         X_train=X_train,
         batch_size=batch_size,
         variational=(hypermodel == "S2SGMVAE"),
-        phenotype_class=0,
-        predictor=predictor,
+        next_sequence_prediction=next_sequence_prediction,
+        phenotype_prediction=phenotype_prediction,
+        rule_based_prediction=rule_based_prediction,
         loss=loss,
         X_val=X_train,
         cp=False,