diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py index 672d96e0685256a3fb7d238633ffd9a096c070d9..ba09f67b0d4025f8165a0ecf0d3042fef7867f61 100644 --- a/deepof/hypermodels.py +++ b/deepof/hypermodels.py @@ -24,10 +24,10 @@ class SEQ_2_SEQ_AE(HyperModel): super().__init__() self.input_shape = input_shape - def build(self, hp): - """Overrides Hypermodel's build method""" + @staticmethod + def get_hparams(hp): + """Retrieve hyperparameters to tune""" - # HYPERPARAMETERS TO TUNE conv_filters = hp.Int( "units_conv", min_value=32, max_value=256, step=32, default=256 ) @@ -42,6 +42,16 @@ class SEQ_2_SEQ_AE(HyperModel): ) encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24) + return conv_filters, lstm_units_1, dense_2, dropout_rate, encoding + + def build(self, hp): + """Overrides Hypermodel's build method""" + + # HYPERPARAMETERS TO TUNE + conv_filters, lstm_units_1, dense_2, dropout_rate, encoding = self.get_hparams( + hp + ) + # INSTANCIATED MODEL model = deepof.models.SEQ_2_SEQ_AE( architecture_hparams={ @@ -92,10 +102,10 @@ class SEQ_2_SEQ_GMVAE(HyperModel): "ELBO" in self.loss or "MMD" in self.loss ), "loss must be one of ELBO, MMD or ELBO+MMD (default)" - def build(self, hp): - """Overrides Hypermodel's build method""" + @staticmethod + def get_hparams(hp): + """Retrieve hyperparameters to tune""" - # Hyperparameters to tune conv_filters = hp.Int( "units_conv", min_value=32, max_value=256, step=32, default=256 ) @@ -110,6 +120,16 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ) encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24) + return conv_filters, lstm_units_1, dense_2, dropout_rate, encoding + + def build(self, hp): + """Overrides Hypermodel's build method""" + + # Hyperparameters to tune + conv_filters, lstm_units_1, dense_2, dropout_rate, encoding = self.get_hparams( + hp + ) + gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE( architecture_hparams={ "units_conv": conv_filters, diff --git a/deepof/requirements.txt b/deepof/requirements.txt index e9111747ba2e13fc8c697d11bc2107adb09d23ba..4f7b7cbe82addcfb95825e0ba2313ffcf86eec64 100644 --- a/deepof/requirements.txt +++ b/deepof/requirements.txt @@ -2,6 +2,7 @@ dash~=1.11.0 hypothesis~=5.29.0 joblib~=0.16.0 keras-tuner~=1.0.1 +kerastuner-tensorboard-logger~=0.2.3 matplotlib~=3.1.3 networkx~=2.4 numpy~=1.18.1 diff --git a/deepof/train_model.py b/deepof/train_model.py index 1d2614737d441fbcb006c03389e9d81cd002e000..162bef3f8eb1677b323cda34a1902f0ceab0cf3e 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -13,7 +13,6 @@ from deepof.data import * from deepof.models import * from deepof.utils import * from train_utils import * -from tensorboard.plugins.hparams import api as hp from tensorflow import keras parser = argparse.ArgumentParser( diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 7fff414f9879e4e1bad0130f28dc3b99e427eb75..a1b5990cb13a5f4b3d49b70a97778e9013289638 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -10,6 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t from datetime import datetime from kerastuner import BayesianOptimization +from kerastuner_tensorboard_logger import TensorBoardLogger from typing import Tuple, Union, Any, List import deepof.hypermodels import deepof.model_utils @@ -176,6 +177,7 @@ def tune_search( seed=42, directory="BayesianOptx", project_name=project_name, + logger=TensorBoardLogger(metrics=["val_mae"], logdir="./logs/hparams"), ) print(tuner.search_space_summary())