Commit 060349e8 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added support for tensorboard HParams while tuning hyperparameters

parent 0361a1bb
Pipeline #84057 passed with stage
in 29 minutes and 30 seconds
......@@ -24,10 +24,10 @@ class SEQ_2_SEQ_AE(HyperModel):
super().__init__()
self.input_shape = input_shape
def build(self, hp):
"""Overrides Hypermodel's build method"""
@staticmethod
def get_hparams(hp):
"""Retrieve hyperparameters to tune"""
# HYPERPARAMETERS TO TUNE
conv_filters = hp.Int(
"units_conv", min_value=32, max_value=256, step=32, default=256
)
......@@ -42,6 +42,16 @@ class SEQ_2_SEQ_AE(HyperModel):
)
encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24)
return conv_filters, lstm_units_1, dense_2, dropout_rate, encoding
def build(self, hp):
"""Overrides Hypermodel's build method"""
# HYPERPARAMETERS TO TUNE
conv_filters, lstm_units_1, dense_2, dropout_rate, encoding = self.get_hparams(
hp
)
# INSTANTIATED MODEL
model = deepof.models.SEQ_2_SEQ_AE(
architecture_hparams={
......@@ -92,10 +102,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
"ELBO" in self.loss or "MMD" in self.loss
), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
def build(self, hp):
"""Overrides Hypermodel's build method"""
@staticmethod
def get_hparams(hp):
"""Retrieve hyperparameters to tune"""
# Hyperparameters to tune
conv_filters = hp.Int(
"units_conv", min_value=32, max_value=256, step=32, default=256
)
......@@ -110,6 +120,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
)
encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24)
return conv_filters, lstm_units_1, dense_2, dropout_rate, encoding
def build(self, hp):
"""Overrides Hypermodel's build method"""
# Hyperparameters to tune
conv_filters, lstm_units_1, dense_2, dropout_rate, encoding = self.get_hparams(
hp
)
gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
architecture_hparams={
"units_conv": conv_filters,
......
......@@ -2,6 +2,7 @@ dash~=1.11.0
hypothesis~=5.29.0
joblib~=0.16.0
keras-tuner~=1.0.1
kerastuner-tensorboard-logger~=0.2.3
matplotlib~=3.1.3
networkx~=2.4
numpy~=1.18.1
......
......@@ -13,7 +13,6 @@ from deepof.data import *
from deepof.models import *
from deepof.utils import *
from train_utils import *
from tensorboard.plugins.hparams import api as hp
from tensorflow import keras
parser = argparse.ArgumentParser(
......
......@@ -10,6 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t
from datetime import datetime
from kerastuner import BayesianOptimization
from kerastuner_tensorboard_logger import TensorBoardLogger
from typing import Tuple, Union, Any, List
import deepof.hypermodels
import deepof.model_utils
......@@ -176,6 +177,7 @@ def tune_search(
seed=42,
directory="BayesianOptx",
project_name=project_name,
logger=TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"),
)
print(tuner.search_space_summary())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment