Commit 00694f7b authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored train_utils.py

parent 6137bb99
......@@ -105,7 +105,7 @@ def get_callbacks(
run_ID = "{}{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"),
("Pred={}".format(predictor) if predictor > 0 and variational else ""),
("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
("_Pheno={}".format(phenotype_class) if phenotype_class > 0 else ""),
("_loss={}".format(loss) if variational else ""),
("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
......
......@@ -85,38 +85,20 @@ def test_get_callbacks(
@settings(max_examples=10, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
X_train=arrays(
dtype=float,
shape=st.tuples(
st.integers(min_value=10, max_value=100),
st.integers(min_value=2, max_value=15),
st.integers(min_value=2, max_value=10),
),
elements=st.floats(
min_value=-1,
max_value=1,
),
),
batch_size=st.integers(min_value=128, max_value=512),
encoding_size=st.integers(min_value=2, max_value=16),
k=st.integers(min_value=1, max_value=10),
loss=st.one_of(st.just("ELBO"), st.just("MMD"), st.just("ELBO+MMD")),
pheno_class=st.one_of(st.just(0.0), st.just(1.0)),
predictor=st.one_of(st.just(0.0), st.just(1.0)),
variational=st.booleans(),
pheno_class=st.one_of(st.just(1.0), st.just(0.0)),
predictor=st.one_of(st.just(1.0), st.just(0.0)),
variational=st.one_of(st.just(True), st.just(False)),
)
def test_autoencoder_fitting(
X_train,
batch_size,
encoding_size,
k,
loss,
pheno_class,
predictor,
variational,
):
y_train = np.round(np.random.uniform(0, 1, X_train.shape[0]))
X_train = np.random.uniform(-1, 1, [20, 5, 6])
y_train = np.round(np.random.uniform(0, 1, (20 if not predictor else 19)))
print(X_train.shape, y_train.shape)
preprocessed_data = (X_train, y_train, X_train, y_train)
prun = deepof.data.project(
......@@ -128,14 +110,14 @@ def test_autoencoder_fitting(
models = prun.deep_unsupervised_embedding(
preprocessed_data,
batch_size=batch_size,
encoding_size=encoding_size,
batch_size=100,
encoding_size=2,
epochs=1,
kl_warmup=10,
log_history=True,
log_hparams=True,
mmd_warmup=10,
n_components=k,
n_components=2,
loss=loss,
phenotype_class=pheno_class,
predictor=predictor,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment