Commit 5693df25 authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed deepof.train_utils.tune_search

parent 6a060a56
Pipeline #103648 passed with stages
in 21 minutes and 36 seconds
...@@ -15,7 +15,7 @@ from typing import Tuple, Union, Any, List ...@@ -15,7 +15,7 @@ from typing import Tuple, Union, Any, List
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from kerastuner import BayesianOptimization, Hyperband from kerastuner import BayesianOptimization, Hyperband, Objective
from kerastuner_tensorboard_logger import TensorBoardLogger from kerastuner_tensorboard_logger import TensorBoardLogger
from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_auc_score
from tensorboard.plugins.hparams import api as hp from tensorboard.plugins.hparams import api as hp
...@@ -562,7 +562,7 @@ def tune_search( ...@@ -562,7 +562,7 @@ def tune_search(
"logger": TensorBoardLogger( "logger": TensorBoardLogger(
metrics=[tuner_objective], logdir=os.path.join(outpath, "logged_hparams") metrics=[tuner_objective], logdir=os.path.join(outpath, "logged_hparams")
), ),
"objective": tuner_objective, "objective": Objective(tuner_objective, direction="min"),
"project_name": project_name, "project_name": project_name,
"tune_new_entries": True, "tune_new_entries": True,
} }
...@@ -588,8 +588,8 @@ def tune_search( ...@@ -588,8 +588,8 @@ def tune_search(
print(tuner.search_space_summary()) print(tuner.search_space_summary())
Xs, ys = [X_train], [X_train] Xs, ys = X_train, [X_train]
Xvals, yvals = [X_val], [X_val] Xvals, yvals = X_val, [X_val]
if next_sequence_prediction > 0.0: if next_sequence_prediction > 0.0:
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]] Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
...@@ -607,6 +607,7 @@ def tune_search( ...@@ -607,6 +607,7 @@ def tune_search(
ys += [y_train[-Xs.shape[0] :]] ys += [y_train[-Xs.shape[0] :]]
yvals += [y_val[-Xvals.shape[0] :]] yvals += [y_val[-Xvals.shape[0] :]]
# Convert data to tf.data.Dataset objects
tuner.search( tuner.search(
Xs, Xs,
ys, ys,
......
...@@ -133,7 +133,7 @@ def test_autoencoder_fitting( ...@@ -133,7 +133,7 @@ def test_autoencoder_fitting(
) )
@settings(max_examples=15, deadline=None) @settings(max_examples=5, deadline=None)
@given( @given(
X_train=arrays( X_train=arrays(
dtype=float, dtype=float,
...@@ -147,6 +147,7 @@ def test_autoencoder_fitting( ...@@ -147,6 +147,7 @@ def test_autoencoder_fitting(
max_value=1, max_value=1,
), ),
), ),
y_train=st.data(),
batch_size=st.integers(min_value=128, max_value=512), batch_size=st.integers(min_value=128, max_value=512),
encoding_size=st.integers(min_value=1, max_value=16), encoding_size=st.integers(min_value=1, max_value=16),
hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")), hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")),
...@@ -159,6 +160,7 @@ def test_autoencoder_fitting( ...@@ -159,6 +160,7 @@ def test_autoencoder_fitting(
) )
def test_tune_search( def test_tune_search(
X_train, X_train,
y_train,
batch_size, batch_size,
encoding_size, encoding_size,
hpt_type, hpt_type,
...@@ -173,9 +175,9 @@ def test_tune_search( ...@@ -173,9 +175,9 @@ def test_tune_search(
deepof.train_utils.get_callbacks( deepof.train_utils.get_callbacks(
X_train=X_train, X_train=X_train,
batch_size=batch_size, batch_size=batch_size,
phenotype_prediction=phenotype_prediction, phenotype_prediction=np.round(phenotype_prediction, 2),
next_sequence_prediction=next_sequence_prediction, next_sequence_prediction=np.round(next_sequence_prediction, 2),
rule_based_prediction=rule_based_prediction, rule_based_prediction=np.round(rule_based_prediction, 2),
loss=loss, loss=loss,
X_val=X_train, X_val=X_train,
input_type=False, input_type=False,
...@@ -189,7 +191,13 @@ def test_tune_search( ...@@ -189,7 +191,13 @@ def test_tune_search(
) )
)[1:] )[1:]
y_train = tf.random.uniform(shape=(X_train.shape[1], 1), maxval=1.0) y_train = y_train.draw(
arrays(
dtype=np.float32,
elements=st.floats(min_value=0.0, max_value=1.0, width=32),
shape=(X_train.shape[1], 1),
)
)
deepof.train_utils.tune_search( deepof.train_utils.tune_search(
data=[X_train, y_train, X_train, y_train], data=[X_train, y_train, X_train, y_train],
...@@ -201,9 +209,9 @@ def test_tune_search( ...@@ -201,9 +209,9 @@ def test_tune_search(
loss=loss, loss=loss,
mmd_warmup_epochs=0, mmd_warmup_epochs=0,
overlap_loss=overlap_loss, overlap_loss=overlap_loss,
next_sequence_prediction=next_sequence_prediction, next_sequence_prediction=np.round(next_sequence_prediction, 2),
phenotype_prediction=phenotype_prediction, phenotype_prediction=np.round(phenotype_prediction, 2),
rule_based_prediction=rule_based_prediction, rule_based_prediction=np.round(rule_based_prediction, 2),
project_name="test_run", project_name="test_run",
callbacks=callbacks, callbacks=callbacks,
n_epochs=1, n_epochs=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment