# test_train_utils.py (5.20 KiB), authored by Lucas Miranda
# @author lucasmiranda42
# encoding: utf-8
# module deepof
"""
Testing module for deepof.train_utils
"""
from hypothesis import given
from hypothesis import HealthCheck
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays
import deepof.data
import deepof.model_utils
import deepof.train_utils
import numpy as np
import os
import tensorflow as tf
def test_load_hparams():
    """Check that ``load_hparams`` returns a dict.

    Covers both a call with ``None`` and a call with the ``test_hparams.pkl``
    file shipped in the single top-view test example.
    """
    assert isinstance(deepof.train_utils.load_hparams(None), dict)

    hparams_path = os.path.join(
        "tests",
        "test_examples",
        "test_single_topview",
        "Others",
        "test_hparams.pkl",
    )
    assert isinstance(deepof.train_utils.load_hparams(hparams_path), dict)
def test_load_treatments():
    """Check the return type of ``load_treatments`` for two directories.

    The current directory ``"."`` is expected to yield ``None``. The ``Others``
    folder of the single top-view test example is expected to yield a dict.
    """
    assert deepof.train_utils.load_treatments(".") is None

    treatments_dir = os.path.join(
        "tests", "test_examples", "test_single_topview", "Others"
    )
    assert isinstance(deepof.train_utils.load_treatments(treatments_dir), dict)
@given(
    X_train=arrays(
        shape=st.tuples(st.integers(min_value=1, max_value=1000)),
        dtype=float,
        elements=st.floats(
            min_value=0.0,
            max_value=1,
        ),
    ),
    batch_size=st.integers(min_value=128, max_value=512),
    loss=st.one_of(st.just("test_A"), st.just("test_B")),
    predictor=st.floats(min_value=0.0, max_value=1.0),
    pheno_class=st.floats(min_value=0.0, max_value=1.0),
    variational=st.booleans(),
)
def test_get_callbacks(
    X_train,
    batch_size,
    variational,
    predictor,
    pheno_class,
    loss,
):
    """Check the four objects returned by ``get_callbacks``.

    The result is unpacked as (run ID, TensorBoard callback, one-cycle
    scheduler, ModelCheckpoint callback), and each element's type is checked.
    """
    runID, tbc, cycle1c, cpc = deepof.train_utils.get_callbacks(
        X_train,
        batch_size,
        True,
        variational,
        pheno_class,
        predictor,
        loss,
        True,
        True,
        None,
    )

    # isinstance rather than exact type equality (PEP 8 recommendation).
    assert isinstance(runID, str)
    assert isinstance(tbc, tf.keras.callbacks.TensorBoard)
    assert isinstance(cpc, tf.keras.callbacks.ModelCheckpoint)
    assert isinstance(cycle1c, deepof.model_utils.one_cycle_scheduler)
@settings(max_examples=10, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
    loss=st.one_of(st.just("ELBO"), st.just("MMD"), st.just("ELBO+MMD")),
    pheno_class=st.one_of(st.just(1.0), st.just(0.0)),
    predictor=st.one_of(st.just(1.0), st.just(0.0)),
    variational=st.one_of(st.just(True), st.just(False)),
)
def test_autoencoder_fitting(
    loss,
    pheno_class,
    predictor,
    variational,
):
    """Fit ``deep_unsupervised_embedding`` for a single epoch on random data.

    The fit runs for each drawn combination of loss, phenotype-class weight,
    predictor weight and the variational flag. The test passes if no
    exception is raised; nothing is asserted about the fitted model.
    """
    # 20 random samples in [-1, 1] and 20 binary labels.
    features = np.random.uniform(-1, 1, [20, 5, 6])
    labels = np.round(np.random.uniform(0, 1, 20))
    # When the predictor weight is non-zero, the first label is dropped.
    labels = labels[1:] if predictor else labels

    project_path = os.path.join(".", "tests", "test_examples", "test_single_topview")
    project_run = deepof.data.project(
        path=project_path,
        arena="circular",
        arena_dims=(380,),
        video_format=".mp4",
    ).run()

    # The same arrays serve as both the training and the validation split.
    project_run.deep_unsupervised_embedding(
        (features, labels, features, labels),
        batch_size=100,
        encoding_size=2,
        epochs=1,
        kl_warmup=1,
        log_history=True,
        log_hparams=True,
        mmd_warmup=1,
        n_components=2,
        loss=loss,
        phenotype_class=pheno_class,
        predictor=predictor,
        variational=variational,
    )
@settings(max_examples=1, deadline=None)
@given(
    X_train=arrays(
        dtype=float,
        shape=st.tuples(
            st.integers(min_value=10, max_value=100),
            st.integers(min_value=2, max_value=15),
            st.integers(min_value=2, max_value=10),
        ),
        elements=st.floats(
            min_value=0.0,
            max_value=1,
        ),
    ),
    batch_size=st.integers(min_value=128, max_value=512),
    encoding_size=st.integers(min_value=1, max_value=16),
    hpt_type=st.one_of(st.just("bayopt"), st.just("hypermodel")),
    hypermodel=st.just("S2SGMVAE"),
    k=st.integers(min_value=1, max_value=10),
    loss=st.one_of(st.just("ELBO"), st.just("MMD")),
    overlap_loss=st.floats(min_value=0.0, max_value=1.0),
    pheno_class=st.floats(min_value=0.0, max_value=1.0),
    predictor=st.floats(min_value=0.0, max_value=1.0),
)
def test_tune_search(
    X_train,
    batch_size,
    encoding_size,
    hpt_type,
    hypermodel,
    k,
    loss,
    overlap_loss,
    pheno_class,
    predictor,
):
    """Run a one-trial, one-epoch hyperparameter search on random data.

    The test passes if ``tune_search`` completes without raising. Nothing is
    asserted about the search result.
    """
    # get_callbacks returns (runID, TensorBoard, scheduler, checkpoint).
    # Drop the run ID and keep only the callback objects.
    callbacks = list(
        deepof.train_utils.get_callbacks(
            X_train,
            batch_size,
            False,
            hypermodel == "S2SGMVAE",
            0,
            predictor,
            loss,
            True,
            True,
            None,
        )
    )[1:]

    # One label per sample, i.e. along axis 0 of X_train, as in
    # test_autoencoder_fitting. The previous shape[1] used the second axis and
    # produced a label vector whose length did not match the sample count.
    y_train = tf.random.uniform(shape=(X_train.shape[0],), maxval=1.0)

    deepof.train_utils.tune_search(
        data=[X_train, y_train, X_train, y_train],
        encoding_size=encoding_size,
        hpt_type=hpt_type,
        hypertun_trials=1,
        hypermodel=hypermodel,
        k=k,
        kl_warmup_epochs=0,
        loss=loss,
        mmd_warmup_epochs=0,
        overlap_loss=overlap_loss,
        phenotype_class=pheno_class,
        predictor=predictor,
        project_name="test_run",
        callbacks=callbacks,
        n_epochs=1,
    )