Commit d207cd1d authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated train_model.py to be compatible with phenotype classification

parent 81cd01a5
...@@ -411,8 +411,7 @@ else: ...@@ -411,8 +411,7 @@ else:
) )
best_hyperparameters, best_model = tune_search( best_hyperparameters, best_model = tune_search(
X_train, data=[X_train, y_train, X_val, y_val],
X_val,
bayopt_trials=bayopt_trials, bayopt_trials=bayopt_trials,
hypermodel=hyp, hypermodel=hyp,
k=k, k=k,
......
...@@ -97,8 +97,7 @@ def get_callbacks( ...@@ -97,8 +97,7 @@ def get_callbacks(
def tune_search( def tune_search(
train: np.array, data: List[np.array],
test: np.array,
bayopt_trials: int, bayopt_trials: int,
hypermodel: str, hypermodel: str,
k: int, k: int,
...@@ -139,12 +138,14 @@ def tune_search( ...@@ -139,12 +138,14 @@ def tune_search(
""" """
X_train, y_train, X_val, y_val = data
if hypermodel == "S2SAE": # pragma: no cover if hypermodel == "S2SAE": # pragma: no cover
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape) hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=X_train.shape)
elif hypermodel == "S2SGMVAE": elif hypermodel == "S2SGMVAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE( hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
input_shape=train.shape, input_shape=X_train.shape,
loss=loss, loss=loss,
number_of_components=k, number_of_components=k,
overlap_loss=overlap_loss, overlap_loss=overlap_loss,
...@@ -168,13 +169,22 @@ def tune_search( ...@@ -168,13 +169,22 @@ def tune_search(
print(tuner.search_space_summary()) print(tuner.search_space_summary())
Xs, ys = [X_train], [X_train]
Xvals, yvals = [X_val], [X_val]
if predictor > 0.0:
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if pheno_class > 0.0:
ys += [y_train]
yvals += [y_val]
tuner.search( tuner.search(
train if predictor == 0 else [train[:-1]], Xs,
train if predictor == 0 else [train[:-1], train[1:]], ys,
epochs=n_epochs, epochs=n_epochs,
validation_data=( validation_data=(Xvals, yvals),
(test, test) if predictor == 0 else (test[:-1], [test[:-1], test[1:]])
),
verbose=1, verbose=1,
batch_size=256, batch_size=256,
callbacks=callbacks, callbacks=callbacks,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment