Commit 7c7e6d62 authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated train_model.py to be compatible with phenotype classification

parent 53bb1046
......@@ -81,26 +81,30 @@ def test_get_callbacks(
k=st.integers(min_value=1, max_value=10),
loss=st.one_of(st.just("ELBO"), st.just("MMD")),
overlap_loss=st.floats(min_value=0.0, max_value=1.0),
pheno_class=st.floats(min_value=0.0, max_value=1.0),
predictor=st.floats(min_value=0.0, max_value=1.0),
)
def test_tune_search(
train, batch_size, hypermodel, k, loss, overlap_loss, predictor,
X_train, batch_size, hypermodel, k, loss, overlap_loss, pheno_class, predictor,
):
callbacks = list(
deepof.train_utils.get_callbacks(
train, batch_size, hypermodel == "S2SGMVAE", predictor, loss,
X_train, batch_size, hypermodel == "S2SGMVAE", predictor, loss,
)
)[1:]
y_train = tf.random.uniform(shape=(X_train.shape[0],), maxval=1.0)
deepof.train_utils.tune_search(
data,
1,
hypermodel,
k,
loss,
overlap_loss,
predictor,
"test_run",
callbacks,
data=[X_train, y_train, X_train, y_train],
bayopt_trials=1,
hypermodel=hypermodel,
k=k,
loss=loss,
overlap_loss=overlap_loss,
pheno_class=pheno_class,
predictor=predictor,
project_name="test_run",
callbacks=callbacks,
n_epochs=1,
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment