From ae0129a96ce1875b69a730213ca061e630d7f2ac Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 13 Oct 2020 20:16:19 +0200
Subject: [PATCH] Added support for tensorboard HParams while tuning
 hyperparameters

---
 deepof/hypermodels.py | 68 ++++++++++++++++++++-----------------------
 deepof/train_model.py |  6 ++--
 deepof/train_utils.py | 29 ++++--------------
 3 files changed, 40 insertions(+), 63 deletions(-)

diff --git a/deepof/hypermodels.py b/deepof/hypermodels.py
index a1b3d4a3..ad09dca2 100644
--- a/deepof/hypermodels.py
+++ b/deepof/hypermodels.py
@@ -29,18 +29,18 @@ class SEQ_2_SEQ_AE(HyperModel):
         """Retrieve hyperparameters to tune"""
 
         conv_filters = hp.Int(
-            "units_conv", min_value=32, max_value=256, step=32, default=256
+            "units_conv", min_value=32, max_value=256, step=32, default=256,
         )
         lstm_units_1 = hp.Int(
-            "units_lstm", min_value=128, max_value=512, step=32, default=256
+            "units_lstm", min_value=128, max_value=512, step=32, default=256,
         )
         dense_2 = hp.Int(
-            "units_dense2", min_value=32, max_value=256, step=32, default=64
+            "units_dense2", min_value=32, max_value=256, step=32, default=64,
         )
         dropout_rate = hp.Float(
-            "dropout_rate", min_value=0.0, max_value=0.5, default=0.25, step=0.05
+            "dropout_rate", min_value=0.0, max_value=0.5, default=0.25, step=0.05,
         )
-        encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24)
+        encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24,)
 
         return conv_filters, lstm_units_1, dense_2, dropout_rate, encoding
 
@@ -74,10 +74,8 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         input_shape,
         entropy_reg_weight=0.0,
         huber_delta=100.0,
-        kl_warmup_epochs=0,
         learn_rate=1e-3,
         loss="ELBO+MMD",
-        mmd_warmup_epochs=0,
         number_of_components=-1,
         overlap_loss=False,
         predictor=0.0,
@@ -87,12 +85,8 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
         self.input_shape = input_shape
         self.entropy_reg_weight = entropy_reg_weight
         self.huber_delta = huber_delta
-        self.kl_warmup = kl_warmup_epochs
-        self.kl_warmup_callback = None
         self.learn_rate = learn_rate
         self.loss = loss
-        self.mmd_warmup = mmd_warmup_epochs
-        self.mmd_warmup_callback = None
         self.number_of_components = number_of_components
         self.overlap_loss = overlap_loss
         self.predictor = predictor
@@ -107,33 +101,39 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
 
         # Architectural hyperparameters
         conv_filters = hp.Int(
-            "units_conv", min_value=32, max_value=256, step=32, default=256
+            "units_conv", min_value=32, max_value=256, step=32, default=256,
         )
         lstm_units_1 = hp.Int(
-            "units_lstm", min_value=128, max_value=512, step=32, default=256
+            "units_lstm", min_value=128, max_value=512, step=32, default=256,
         )
         dense_2 = hp.Int(
-            "units_dense2", min_value=32, max_value=256, step=32, default=64
+            "units_dense2", min_value=32, max_value=256, step=32, default=64,
         )
         dropout_rate = hp.Float(
-            "dropout_rate", min_value=0.0, max_value=0.5, default=0.25, step=0.05
+            "dropout_rate",
+            min_value=0.0,
+            max_value=0.5,
+            default=0.25,
+            sampling="linear",
+        )
+        encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24,)
+        k = hp.Int(
+            "n_components",
+            min_value=1,
+            max_value=15,
+            step=1,
+            default=self.number_of_components,
+            sampling="linear",
         )
-        encoding = hp.Int("encoding", min_value=16, max_value=64, step=8, default=24)
-
-        # Conditional hyperparameters
-        for placeholder, hparam in zip(
-            ["number_of_components", "kl_warmup", "mmd_warmup"],
-            [
-                hp.Int("n_components", min_value=1, max_value=15, step=1, default=5),
-                hp.Int("kl_warmup", min_value=0, max_value=20, step=5, default=10),
-                hp.Int("mmd_warmup", min_value=0, max_value=20, step=5, default=10),
-            ],
-        ):
-
-            if getattr(self, placeholder) == -1:
-                setattr(self, placeholder, hparam)
 
-        return conv_filters, lstm_units_1, dense_2, dropout_rate, encoding
+        return (
+            conv_filters,
+            lstm_units_1,
+            dense_2,
+            dropout_rate,
+            encoding,
+            k,
+        )
 
     def build(self, hp):
         """Overrides Hypermodel's build method"""
@@ -145,6 +145,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             dense_2,
             dropout_rate,
             encoding,
+            k,
         ) = self.get_hparams(hp)
 
         gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
@@ -157,17 +158,12 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             },
             entropy_reg_weight=self.entropy_reg_weight,
             huber_delta=self.huber_delta,
-            kl_warmup_epochs=self.kl_warmup,
             loss=self.loss,
-            mmd_warmup_epochs=self.mmd_warmup,
-            number_of_components=self.number_of_components,
+            number_of_components=k,
             overlap_loss=self.overlap_loss,
             predictor=self.predictor,
         ).build(self.input_shape)[3:]
 
-        self.kl_warmup_callback = kl_warmup_callback
-        self.mmd_warmup_callback = mmd_warmup_callback
-
         return gmvaep
 
 
diff --git a/deepof/train_model.py b/deepof/train_model.py
index 61f21127..4f27def2 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -278,7 +278,7 @@ if not tune:
         tf.keras.backend.clear_session()
 
         run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
-            X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
+            X_train, batch_size, variational, predictor, loss,
         )
 
         if not variational:
@@ -373,7 +373,7 @@ else:
     hyp = "S2SGMVAE" if variational else "S2SAE"
 
     run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
-        X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
+        X_train, batch_size, variational, predictor, loss
     )
 
     best_hyperparameters, best_model = tune_search(
@@ -382,9 +382,7 @@ else:
         bayopt_trials=bayopt_trials,
         hypermodel=hyp,
         k=k,
-        kl_wu=kl_wu,
         loss=loss,
-        mmd_wu=mmd_wu,
         overlap_loss=overlap_loss,
         predictor=predictor,
         project_name="{}-based_{}_BAYESIAN_OPT".format(input_type, hyp),
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 919ed0e0..f2dcfb12 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -10,6 +10,7 @@ Simple utility functions used in deepof example scripts. These are not part of t
 from datetime import datetime
 
 from kerastuner import BayesianOptimization
+from kerastuner import HyperParameters
 from kerastuner_tensorboard_logger import TensorBoardLogger
 from typing import Tuple, Union, Any, List
 import deepof.hypermodels
@@ -19,6 +20,8 @@ import os
 import pickle
 import tensorflow as tf
 
+hp = HyperParameters()
+
 
 def load_hparams(hparams):
     """Loads hyperparameters from a custom dictionary pickled on disc.
@@ -58,14 +61,7 @@ def load_treatments(train_path):
 
 
 def get_callbacks(
-    X_train: np.array,
-    batch_size: int,
-    variational: bool,
-    predictor: float,
-    k: int,
-    loss: str,
-    kl_wu: int,
-    mmd_wu: int,
+    X_train: np.array, batch_size: int, variational: bool, predictor: float, loss: str,
 ) -> Tuple:
     """Generates callbacks for model training, including:
         - run_ID: run name, with coarse parameter details;
@@ -73,13 +69,10 @@ def get_callbacks(
         - cp_callback: for checkpoint saving,
         - onecycle: for learning rate scheduling"""
 
-    run_ID = "{}{}{}{}{}{}_{}".format(
+    run_ID = "{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
         ("P" if predictor > 0 and variational else ""),
-        ("_components={}".format(k) if variational else ""),
         ("_loss={}".format(loss) if variational else ""),
-        ("_kl_warmup={}".format(kl_wu) if variational else ""),
-        ("_mmd_warmup={}".format(mmd_wu) if variational else ""),
         datetime.now().strftime("%Y%m%d-%H%M%S"),
     )
 
@@ -109,9 +102,7 @@ def tune_search(
     bayopt_trials: int,
     hypermodel: str,
     k: int,
-    kl_wu: int,
     loss: str,
-    mmd_wu: int,
     overlap_loss: float,
     predictor: float,
     project_name: str,
@@ -128,9 +119,7 @@ def tune_search(
             - hypermodel (str): hypermodel to load. Must be one of S2SAE (plain autoencoder)
             or S2SGMVAE (Gaussian Mixture Variational autoencoder).
             - k (int) number of components of the Gaussian Mixture
-            - kl_wu (int): number of epochs for KL divergence warm up
             - loss (str): one of [ELBO, MMD, ELBO+MMD]
-            - mmd_wu (int): number of epochs for MMD warm up
             - overlap_loss (float): assigns as weight to an extra loss term which
             penalizes overlap between GM components
             - predictor (float): adds an extra regularizing neural network to the model,
@@ -153,19 +142,12 @@ def tune_search(
     elif hypermodel == "S2SGMVAE":
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_GMVAE(
             input_shape=train.shape,
-            kl_warmup_epochs=kl_wu,
             loss=loss,
-            mmd_warmup_epochs=mmd_wu,
             number_of_components=k,
             overlap_loss=overlap_loss,
             predictor=predictor,
         )
 
-        # if "ELBO" in loss and kl_wu > 0:
-        #     callbacks.append(hypermodel.kl_warmup_callback)
-        # if "MMD" in loss and mmd_wu > 0:
-        #     callbacks.append(hypermodel.mmd_warmup_callback)
-
     else:
         return False
 
@@ -178,6 +160,7 @@ def tune_search(
         objective="val_mae",
         project_name=project_name,
         seed=42,
+        tune_new_entries=True,
     )
 
     print(tuner.search_space_summary())
-- 
GitLab