Commit bbe0b960 authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed deepof.train_utils.tune_search

parent 5693df25
Pipeline #103680 canceled with stages
in 3 minutes and 35 seconds
...@@ -607,12 +607,22 @@ def tune_search( ...@@ -607,12 +607,22 @@ def tune_search(
ys += [y_train[-Xs.shape[0] :]] ys += [y_train[-Xs.shape[0] :]]
yvals += [y_val[-Xvals.shape[0] :]] yvals += [y_val[-Xvals.shape[0] :]]
# Convert data to tf.data.Dataset objects
train_dataset = (
tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
.batch(batch_size, drop_remainder=True)
.shuffle(buffer_size=X_train.shape[0])
)
val_dataset = (
tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
.batch(batch_size, drop_remainder=True)
)
# Convert data to tf.data.Dataset objects # Convert data to tf.data.Dataset objects
tuner.search( tuner.search(
Xs, train_dataset,
ys,
epochs=n_epochs, epochs=n_epochs,
validation_data=(Xvals, yvals), validation_data=val_dataset,
verbose=1, verbose=1,
batch_size=batch_size, batch_size=batch_size,
callbacks=callbacks, callbacks=callbacks,
......
...@@ -8,7 +8,7 @@ Testing module for deepof.train_utils ...@@ -8,7 +8,7 @@ Testing module for deepof.train_utils
""" """
from hypothesis import given from hypothesis import given, reproduce_failure
from hypothesis import HealthCheck from hypothesis import HealthCheck
from hypothesis import settings from hypothesis import settings
from hypothesis import strategies as st from hypothesis import strategies as st
...@@ -138,8 +138,8 @@ def test_autoencoder_fitting( ...@@ -138,8 +138,8 @@ def test_autoencoder_fitting(
X_train=arrays( X_train=arrays(
dtype=float, dtype=float,
shape=st.tuples( shape=st.tuples(
st.integers(min_value=10, max_value=100), st.integers(min_value=128, max_value=512),
st.integers(min_value=10, max_value=15), st.integers(min_value=128, max_value=512),
st.integers(min_value=2, max_value=10), st.integers(min_value=2, max_value=10),
), ),
elements=st.floats( elements=st.floats(
...@@ -148,7 +148,7 @@ def test_autoencoder_fitting( ...@@ -148,7 +148,7 @@ def test_autoencoder_fitting(
), ),
), ),
y_train=st.data(), y_train=st.data(),
batch_size=st.integers(min_value=128, max_value=512), batch_size=st.just(128),
encoding_size=st.integers(min_value=1, max_value=16), encoding_size=st.integers(min_value=1, max_value=16),
hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")), hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")),
k=st.integers(min_value=1, max_value=10), k=st.integers(min_value=1, max_value=10),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment