Skip to content
Snippets Groups Projects
Commit 5693df25 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Fixed deepof.train_utils.tune_search

parent 6a060a56
Branches
No related tags found
No related merge requests found
Pipeline #103648 passed
...@@ -15,7 +15,7 @@ from typing import Tuple, Union, Any, List ...@@ -15,7 +15,7 @@ from typing import Tuple, Union, Any, List
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from kerastuner import BayesianOptimization, Hyperband from kerastuner import BayesianOptimization, Hyperband, Objective
from kerastuner_tensorboard_logger import TensorBoardLogger from kerastuner_tensorboard_logger import TensorBoardLogger
from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_auc_score
from tensorboard.plugins.hparams import api as hp from tensorboard.plugins.hparams import api as hp
...@@ -562,7 +562,7 @@ def tune_search( ...@@ -562,7 +562,7 @@ def tune_search(
"logger": TensorBoardLogger( "logger": TensorBoardLogger(
metrics=[tuner_objective], logdir=os.path.join(outpath, "logged_hparams") metrics=[tuner_objective], logdir=os.path.join(outpath, "logged_hparams")
), ),
"objective": tuner_objective, "objective": Objective(tuner_objective, direction="min"),
"project_name": project_name, "project_name": project_name,
"tune_new_entries": True, "tune_new_entries": True,
} }
...@@ -588,8 +588,8 @@ def tune_search( ...@@ -588,8 +588,8 @@ def tune_search(
print(tuner.search_space_summary()) print(tuner.search_space_summary())
Xs, ys = [X_train], [X_train] Xs, ys = X_train, [X_train]
Xvals, yvals = [X_val], [X_val] Xvals, yvals = X_val, [X_val]
if next_sequence_prediction > 0.0: if next_sequence_prediction > 0.0:
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]] Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
...@@ -607,6 +607,7 @@ def tune_search( ...@@ -607,6 +607,7 @@ def tune_search(
ys += [y_train[-Xs.shape[0] :]] ys += [y_train[-Xs.shape[0] :]]
yvals += [y_val[-Xvals.shape[0] :]] yvals += [y_val[-Xvals.shape[0] :]]
# Convert data to tf.data.Dataset objects
tuner.search( tuner.search(
Xs, Xs,
ys, ys,
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
...@@ -133,7 +133,7 @@ def test_autoencoder_fitting( ...@@ -133,7 +133,7 @@ def test_autoencoder_fitting(
) )
@settings(max_examples=15, deadline=None) @settings(max_examples=5, deadline=None)
@given( @given(
X_train=arrays( X_train=arrays(
dtype=float, dtype=float,
...@@ -147,6 +147,7 @@ def test_autoencoder_fitting( ...@@ -147,6 +147,7 @@ def test_autoencoder_fitting(
max_value=1, max_value=1,
), ),
), ),
y_train=st.data(),
batch_size=st.integers(min_value=128, max_value=512), batch_size=st.integers(min_value=128, max_value=512),
encoding_size=st.integers(min_value=1, max_value=16), encoding_size=st.integers(min_value=1, max_value=16),
hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")), hpt_type=st.one_of(st.just("bayopt"), st.just("hyperband")),
...@@ -159,6 +160,7 @@ def test_autoencoder_fitting( ...@@ -159,6 +160,7 @@ def test_autoencoder_fitting(
) )
def test_tune_search( def test_tune_search(
X_train, X_train,
y_train,
batch_size, batch_size,
encoding_size, encoding_size,
hpt_type, hpt_type,
...@@ -173,9 +175,9 @@ def test_tune_search( ...@@ -173,9 +175,9 @@ def test_tune_search(
deepof.train_utils.get_callbacks( deepof.train_utils.get_callbacks(
X_train=X_train, X_train=X_train,
batch_size=batch_size, batch_size=batch_size,
phenotype_prediction=phenotype_prediction, phenotype_prediction=np.round(phenotype_prediction, 2),
next_sequence_prediction=next_sequence_prediction, next_sequence_prediction=np.round(next_sequence_prediction, 2),
rule_based_prediction=rule_based_prediction, rule_based_prediction=np.round(rule_based_prediction, 2),
loss=loss, loss=loss,
X_val=X_train, X_val=X_train,
input_type=False, input_type=False,
...@@ -189,7 +191,13 @@ def test_tune_search( ...@@ -189,7 +191,13 @@ def test_tune_search(
) )
)[1:] )[1:]
y_train = tf.random.uniform(shape=(X_train.shape[1], 1), maxval=1.0) y_train = y_train.draw(
arrays(
dtype=np.float32,
elements=st.floats(min_value=0.0, max_value=1.0, width=32),
shape=(X_train.shape[1], 1),
)
)
deepof.train_utils.tune_search( deepof.train_utils.tune_search(
data=[X_train, y_train, X_train, y_train], data=[X_train, y_train, X_train, y_train],
...@@ -201,9 +209,9 @@ def test_tune_search( ...@@ -201,9 +209,9 @@ def test_tune_search(
loss=loss, loss=loss,
mmd_warmup_epochs=0, mmd_warmup_epochs=0,
overlap_loss=overlap_loss, overlap_loss=overlap_loss,
next_sequence_prediction=next_sequence_prediction, next_sequence_prediction=np.round(next_sequence_prediction, 2),
phenotype_prediction=phenotype_prediction, phenotype_prediction=np.round(phenotype_prediction, 2),
rule_based_prediction=rule_based_prediction, rule_based_prediction=np.round(rule_based_prediction, 2),
project_name="test_run", project_name="test_run",
callbacks=callbacks, callbacks=callbacks,
n_epochs=1, n_epochs=1,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment