Commit d207cd1d authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated train_model.py to be compatible with phenotype classification

parent 81cd01a5
......@@ -411,8 +411,7 @@ else:
)
best_hyperparameters, best_model = tune_search(
X_train,
X_val,
data=[X_train, y_train, X_val, y_val],
bayopt_trials=bayopt_trials,
hypermodel=hyp,
k=k,
......
......@@ -97,8 +97,7 @@ def get_callbacks(
def tune_search(
train: np.array,
test: np.array,
data: List[np.array],
bayopt_trials: int,
hypermodel: str,
k: int,
......@@ -139,12 +138,14 @@ def tune_search(
"""
X_train, y_train, X_val, y_val = data
if hypermodel == "S2SAE": # pragma: no cover
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
elif hypermodel == "S2SGMVAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
input_shape=train.shape,
input_shape=X_train.shape,
loss=loss,
number_of_components=k,
overlap_loss=overlap_loss,
......@@ -168,13 +169,22 @@ def tune_search(
print(tuner.search_space_summary())
Xs, ys = [X_train], [X_train]
Xvals, yvals = [X_val], [X_val]
if predictor > 0.0:
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if pheno_class > 0.0:
ys += [y_train]
yvals += [y_val]
tuner.search(
train if predictor == 0 else [train[:-1]],
train if predictor == 0 else [train[:-1], train[1:]],
Xs,
ys,
epochs=n_epochs,
validation_data=(
(test, test) if predictor == 0 else (test[:-1], [test[:-1], test[1:]])
),
validation_data=(Xvals, yvals),
verbose=1,
batch_size=256,
callbacks=callbacks,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment