diff --git a/deepof/train_model.py b/deepof/train_model.py index 2448108afdfbd6061daf8385bc6605a02d0a8a4e..122555f0f70ab9df004c5c6ee95735247a8f802e 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -411,8 +411,7 @@ else: ) best_hyperparameters, best_model = tune_search( - X_train, - X_val, + data=[X_train, y_train, X_val, y_val], bayopt_trials=bayopt_trials, hypermodel=hyp, k=k, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 000f52addb945b64b27d364327653320ac5a2b96..b545268f3061c32f12adbcdb0241920baa517c56 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -97,8 +97,7 @@ def get_callbacks( def tune_search( - train: np.array, - test: np.array, + data: List[np.array], bayopt_trials: int, hypermodel: str, k: int, @@ -139,12 +138,14 @@ def tune_search( """ + X_train, y_train, X_val, y_val = data + if hypermodel == "S2SAE": # pragma: no cover - hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape) + hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape) elif hypermodel == "S2SGMVAE": hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE( - input_shape=train.shape, + input_shape=X_train.shape, loss=loss, number_of_components=k, overlap_loss=overlap_loss, @@ -168,13 +169,22 @@ def tune_search( print(tuner.search_space_summary()) + Xs, ys = [X_train], [X_train] + Xvals, yvals = [X_val], [X_val] + + if predictor > 0.0: + Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]] + Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]] + + if pheno_class > 0.0: + ys += [y_train] + yvals += [y_val] + tuner.search( - train if predictor == 0 else [train[:-1]], - train if predictor == 0 else [train[:-1], train[1:]], + Xs, + ys, epochs=n_epochs, - validation_data=( - (test, test) if predictor == 0 else (test[:-1], [test[:-1], test[1:]]) - ), + validation_data=(Xvals, yvals), verbose=1, batch_size=256, callbacks=callbacks,