Commit 18de9b3c authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored hyperparameter tuning in train_utils.py

parent 79d8ede5
......@@ -116,7 +116,7 @@ def tune_search(
predictor: float,
project_name: str,
callbacks: List,
n_epochs: int = 30,
n_epochs: int = 40,
n_replicas: int = 1,
) -> Union[bool, Tuple[Any, Any]]:
"""Define the search space using keras-tuner and bayesian optimization
......
......@@ -16,12 +16,8 @@ import deepof.model_utils
import deepof.train_utils
import keras
import os
import pytest
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor
# For coverage.py to work with @tf.function decorated functions and methods,
# graph execution is disabled when running this script with pytest
tf.config.experimental_run_functions_eagerly(True)
def test_load_hparams():
......@@ -50,7 +46,9 @@ def test_load_treatments():
@given(
X_train=arrays(
shape=st.tuples(st.integers(min_value=1, max_value=1000)), dtype=float
shape=st.tuples(st.integers(min_value=1, max_value=1000)),
dtype=float,
elements=st.floats(min_value=0.0, max_value=1,),
),
batch_size=st.integers(min_value=128, max_value=512),
k=st.integers(min_value=1, max_value=50),
......@@ -72,7 +70,7 @@ def test_get_callbacks(
assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
@settings(max_examples=1)
@settings(max_examples=1, deadline=None)
@given(
train=arrays(
dtype=float,
......@@ -81,14 +79,7 @@ def test_get_callbacks(
st.integers(min_value=2, max_value=15),
st.integers(min_value=2, max_value=10),
),
),
test=arrays(
dtype=float,
shape=st.tuples(
st.integers(min_value=10, max_value=100),
st.integers(min_value=2, max_value=15),
st.integers(min_value=2, max_value=10),
),
elements=st.floats(min_value=0.0, max_value=1,),
),
batch_size=st.integers(min_value=128, max_value=512),
hypermodel=st.one_of(st.just("S2SAE"), st.just("S2SGMVAE")),
......@@ -98,11 +89,9 @@ def test_get_callbacks(
mmd_wu=st.integers(min_value=0, max_value=10),
overlap_loss=st.floats(min_value=0.0, max_value=1.0),
predictor=st.floats(min_value=0.0, max_value=1.0),
project_name=st.text(min_size=5),
)
def test_tune_search(
train,
test,
batch_size,
hypermodel,
k,
......@@ -111,7 +100,6 @@ def test_tune_search(
mmd_wu,
overlap_loss,
predictor,
project_name,
):
callbacks = list(
deepof.train_utils.get_callbacks(
......@@ -124,11 +112,11 @@ def test_tune_search(
kl_wu,
mmd_wu,
)
)
)[1:]
deepof.train_utils.tune_search(
train,
test,
train,
1,
hypermodel,
k,
......@@ -137,7 +125,7 @@ def test_tune_search(
mmd_wu,
overlap_loss,
predictor,
project_name,
"test_run",
callbacks,
n_epochs=1,
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment