From d207cd1d5689f8ff95e9395a24c13f4e17425d83 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Mon, 16 Nov 2020 22:14:31 +0100 Subject: [PATCH] Updated train_model.py to be compatible with phenotype classification --- deepof/train_model.py | 3 +-- deepof/train_utils.py | 28 +++++++++++++++++++--------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/deepof/train_model.py b/deepof/train_model.py index 2448108a..122555f0 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -411,8 +411,7 @@ else: ) best_hyperparameters, best_model = tune_search( - X_train, - X_val, + data=[X_train, y_train, X_val, y_val], bayopt_trials=bayopt_trials, hypermodel=hyp, k=k, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 000f52ad..b545268f 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -97,8 +97,7 @@ def get_callbacks( def tune_search( - train: np.array, - test: np.array, + data: List[np.array], bayopt_trials: int, hypermodel: str, k: int, @@ -139,12 +138,14 @@ def tune_search( """ + X_train, y_train, X_val, y_val = data + if hypermodel == "S2SAE": # pragma: no cover - hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape) + hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape) elif hypermodel == "S2SGMVAE": hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE( - input_shape=train.shape, + input_shape=X_train.shape, loss=loss, number_of_components=k, overlap_loss=overlap_loss, @@ -168,13 +169,22 @@ def tune_search( print(tuner.search_space_summary()) + Xs, ys = [X_train], [X_train] + Xvals, yvals = [X_val], [X_val] + + if predictor > 0.0: + Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]] + Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]] + + if pheno_class > 0.0: + ys += [y_train] + yvals += [y_val] + tuner.search( - train if predictor == 0 else [train[:-1]], - train if predictor == 0 else [train[:-1], train[1:]], + Xs, + ys, epochs=n_epochs, - validation_data=( - (test, test) if predictor == 0 else (test[:-1], [test[:-1], test[1:]]) - ), + validation_data=(Xvals, yvals), verbose=1, batch_size=256, callbacks=callbacks, -- GitLab