Commit bbe0b960 authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed deepof.train_utils.tune_search

parent 5693df25
Pipeline #103680 canceled with stages
in 3 minutes and 35 seconds
......@@ -607,12 +607,22 @@ def tune_search(
ys += [y_train[-Xs.shape[0] :]]
yvals += [y_val[-Xvals.shape[0] :]]
# Convert data to tf.data.Dataset objects
train_dataset = (
tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
.batch(batch_size, drop_remainder=True)
.shuffle(buffer_size=X_train.shape[0])
)
val_dataset = (
tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
.batch(batch_size, drop_remainder=True)
)
# Convert data to tf.data.Dataset objects
tuner.search(
Xs,
ys,
train_dataset,
epochs=n_epochs,
validation_data=(Xvals, yvals),
validation_data=val_dataset,
verbose=1,
batch_size=batch_size,
callbacks=callbacks,
......
......@@ -8,7 +8,7 @@ Testing module for deepof.train_utils
"""
from hypothesis import given
from hypothesis import given, reproduce_failure
from hypothesis import HealthCheck
from hypothesis import settings
from hypothesis import strategies as st
......@@ -138,8 +138,8 @@ def test_autoencoder_fitting(
X_train=arrays(
dtype=float,
shape=st.tuples(
st.integers(min_value=10, max_value=100),
st.integers(min_value=10, max_value=15),
st.integers(min_value=128, max_value=512),
st.integers(min_value=128, max_value=512),
st.integers(min_value=2, max_value=10),
),
elements=st.floats(
......@@ -148,7 +148,7 @@ def test_autoencoder_fitting(
),
),
y_train=st.data(),
batch_size=st.integers(min_value=128, max_value=512),
batch_size=st.just(128),
encoding_size=st.integers(min_value=1, max_value=16),
hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")),
k=st.integers(min_value=1, max_value=10),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment