Commit e588caef authored by lucas_miranda's avatar lucas_miranda
Browse files

Added support for tensorboard HParams while tuning hyperparameters

parent 060349e8
...@@ -12,8 +12,8 @@ usage: python -m examples.model_training -h ...@@ -12,8 +12,8 @@ usage: python -m examples.model_training -h
from deepof.data import * from deepof.data import *
from deepof.models import * from deepof.models import *
from deepof.utils import * from deepof.utils import *
from train_utils import *
from tensorflow import keras from tensorflow import keras
from train_utils import *
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Autoencoder training for DeepOF animal pose recognition" description="Autoencoder training for DeepOF animal pose recognition"
......
...@@ -10,6 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t ...@@ -10,6 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t
from datetime import datetime from datetime import datetime
from kerastuner import BayesianOptimization from kerastuner import BayesianOptimization
from kerastuner import HyperParameters
from kerastuner_tensorboard_logger import TensorBoardLogger from kerastuner_tensorboard_logger import TensorBoardLogger
from typing import Tuple, Union, Any, List from typing import Tuple, Union, Any, List
import deepof.hypermodels import deepof.hypermodels
...@@ -19,6 +20,8 @@ import os ...@@ -19,6 +20,8 @@ import os
import pickle import pickle
import tensorflow as tf import tensorflow as tf
hp = HyperParameters()
def load_hparams(hparams): def load_hparams(hparams):
"""Loads hyperparameters from a custom dictionary pickled on disc. """Loads hyperparameters from a custom dictionary pickled on disc.
...@@ -153,31 +156,34 @@ def tune_search( ...@@ -153,31 +156,34 @@ def tune_search(
elif hypermodel == "S2SGMVAE": elif hypermodel == "S2SGMVAE":
hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE( hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
input_shape=train.shape, input_shape=train.shape,
loss=loss,
number_of_components=k,
kl_warmup_epochs=kl_wu, kl_warmup_epochs=kl_wu,
loss=loss,
mmd_warmup_epochs=mmd_wu, mmd_warmup_epochs=mmd_wu,
predictor=predictor, number_of_components=k,
overlap_loss=overlap_loss, overlap_loss=overlap_loss,
predictor=predictor,
) )
try:
if "ELBO" in loss and kl_wu > 0: if "ELBO" in loss and kl_wu > 0:
callbacks.append(hypermodel.kl_warmup_callback) callbacks.append(hypermodel.kl_warmup_callback)
if "MMD" in loss and mmd_wu > 0: if "MMD" in loss and mmd_wu > 0:
callbacks.append(hypermodel.mmd_warmup_callback) callbacks.append(hypermodel.mmd_warmup_callback)
except AttributeError:
pass
else: else:
return False return False
tuner = BayesianOptimization( tuner = BayesianOptimization(
hypermodel, hypermodel,
max_trials=bayopt_trials, directory="BayesianOptx",
executions_per_trial=n_replicas, executions_per_trial=n_replicas,
logger=TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"),
max_trials=bayopt_trials,
objective="val_mae", objective="val_mae",
seed=42,
directory="BayesianOptx",
project_name=project_name, project_name=project_name,
logger=TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"), seed=42,
) )
print(tuner.search_space_summary()) print(tuner.search_space_summary())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment