diff --git a/deepof/train_utils.py b/deepof/train_utils.py index ca9245983795fbfeb7780aaccf8cee3d68016d4a..189c5c8df23db24a731a3023807e0f2ddef93245 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -116,7 +116,7 @@ def tune_search( predictor: float, project_name: str, callbacks: List, - n_epochs: int = 30, + n_epochs: int = 40, n_replicas: int = 1, ) -> Union[bool, Tuple[Any, Any]]: """Define the search space using keras-tuner and bayesian optimization diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py index 68a6fa155cfc8add42c6be9d1c4b5f54ca1fe8bf..ed3c4f57439e013b2212c69e5c2fb541c3b63564 100644 --- a/tests/test_train_utils.py +++ b/tests/test_train_utils.py @@ -16,12 +16,8 @@ import deepof.model_utils import deepof.train_utils import keras import os +import pytest import tensorflow as tf -from tensorflow.python.framework.ops import EagerTensor - -# For coverage.py to work with @tf.function decorated functions and methods, -# graph execution is disabled when running this script with pytest -tf.config.experimental_run_functions_eagerly(True) def test_load_hparams(): @@ -50,7 +46,9 @@ def test_load_treatments(): @given( X_train=arrays( - shape=st.tuples(st.integers(min_value=1, max_value=1000)), dtype=float + shape=st.tuples(st.integers(min_value=1, max_value=1000)), + dtype=float, + elements=st.floats(min_value=0.0, max_value=1,), ), batch_size=st.integers(min_value=128, max_value=512), k=st.integers(min_value=1, max_value=50), @@ -72,7 +70,7 @@ def test_get_callbacks( assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler -@settings(max_examples=1) +@settings(max_examples=1, deadline=None) @given( train=arrays( dtype=float, @@ -81,14 +79,7 @@ def test_get_callbacks( st.integers(min_value=2, max_value=15), st.integers(min_value=2, max_value=10), ), - ), - test=arrays( - dtype=float, - shape=st.tuples( - st.integers(min_value=10, max_value=100), - st.integers(min_value=2, max_value=15), - st.integers(min_value=2, max_value=10), - ), + elements=st.floats(min_value=0.0, max_value=1,), ), batch_size=st.integers(min_value=128, max_value=512), hypermodel=st.one_of(st.just("S2SAE"), st.just("S2SGMVAE")), @@ -98,11 +89,9 @@ def test_get_callbacks( mmd_wu=st.integers(min_value=0, max_value=10), overlap_loss=st.floats(min_value=0.0, max_value=1.0), predictor=st.floats(min_value=0.0, max_value=1.0), - project_name=st.text(min_size=5), ) def test_tune_search( train, - test, batch_size, hypermodel, k, @@ -111,7 +100,6 @@ def test_tune_search( mmd_wu, overlap_loss, predictor, - project_name, ): callbacks = list( deepof.train_utils.get_callbacks( @@ -124,11 +112,11 @@ def test_tune_search( kl_wu, mmd_wu, ) - ) + )[1:] deepof.train_utils.tune_search( train, - test, + train, 1, hypermodel, k, @@ -137,7 +125,7 @@ def test_tune_search( mmd_wu, overlap_loss, predictor, - project_name, + "test_run", callbacks, n_epochs=1, )