From f78f765f14a2598d659a8e48a934cd913865ff2b Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 3 Jul 2020 19:03:47 +0200
Subject: [PATCH] Implemented weight saving callback in model_training.py

---
 model_training.py     | 12 ++++++++----
 source/model_utils.py | 24 ++++++++++++++++++++++++
 source/models.py      | 16 +++++++++++++++-
 3 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/model_training.py b/model_training.py
index 8e76b24b..b8dbb19b 100644
--- a/model_training.py
+++ b/model_training.py
@@ -371,10 +371,10 @@ if not variational:
         validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
         callbacks=[
             tensorboard_callback,
+            cp_callback,
             tf.keras.callbacks.EarlyStopping(
                 "val_mae", patience=5, restore_best_weights=True
             ),
-            cp_callback,
         ],
     )
 
@@ -384,6 +384,8 @@ else:
         generator,
         grouper,
         gmvaep,
+        dead_neuron_rate_callback,
+        silhouette_callback,
         kl_warmup_callback,
         mmd_warmup_callback,
     ) = SEQ_2_SEQ_GMVAE(
@@ -401,15 +403,17 @@ else:
 
     callbacks_ = [
         tensorboard_callback,
+        cp_callback,
+        dead_neuron_rate_callback,
+        silhouette_callback,
         tf.keras.callbacks.EarlyStopping(
             "val_mae", patience=5, restore_best_weights=True
         ),
-        cp_callback,
     ]
 
-    if "ELBO" in loss:
+    if "ELBO" in loss and kl_wu > 0:
         callbacks_.append(kl_warmup_callback)
-    if "MMD" in loss:
+    if "MMD" in loss and mmd_wu > 0:
         callbacks_.append(mmd_warmup_callback)
 
     if not predictor:
diff --git a/source/model_utils.py b/source/model_utils.py
index 32f3112f..17447fa2 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -1,6 +1,7 @@
 # @author lucasmiranda42
 
 from keras import backend as K
+from sklearn.metrics import silhouette_score
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
 import tensorflow as tf
@@ -150,3 +151,26 @@ class MMDiscrepancyLayer(Layer):
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
 
         return z
+
+
+class Latent_space_control(Layer):
+    """ Identity layer that adds latent space and clustering stats
+     to the metrics compiled by the model
+     """
+
+    def __init__(self, *args, **kwargs):
+        super(Latent_space_control, self).__init__(*args, **kwargs)
+
+    def call(self, z, z_gauss, z_cat, **kwargs):
+
+        # Adds metric that monitors dead neurons in the latent space
+        self.add_metric(
+            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
+        )
+
+        # Adds Silhouette score controling overlap between clusters
+        hard_labels = tf.math.argmax(z_cat, axis=1)
+        silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
+        self.add_metric(silhouette, aggregation="mean", name="silhouette")
+
+        return z
diff --git a/source/models.py b/source/models.py
index d1e0a197..a54bd76d 100644
--- a/source/models.py
+++ b/source/models.py
@@ -344,6 +344,18 @@ class SEQ_2_SEQ_GMVAE:
 
             z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
 
+        # z = Latent_space_control()(z, z_gauss, z_cat)
+
+        # Latent space callback to control dead (zero) dimensions in the latent space
+        dead_neuron_rate_callback = LambdaCallback(
+            on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
+        )
+
+        # Latent space callback to control the latent silhouette clustering index
+        silhouette_callback = LambdaCallback(
+            on_epoch_end=tf.numpy_function(silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32)
+        )
+
         # Define and instantiate generator
         generator = Model_D1(z)
         generator = Model_B1(generator)
@@ -429,6 +441,8 @@ class SEQ_2_SEQ_GMVAE:
             generator,
             grouper,
             gmvaep,
+            dead_neuron_rate_callback,
+            silhouette_callback,
             kl_warmup_callback,
             mmd_warmup_callback,
         )
@@ -437,7 +451,7 @@ class SEQ_2_SEQ_GMVAE:
 # TODO:
 #       - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
 #       - Clustering metrics for model selection and aid training (eg early stopping)
-#           - Silhouette / likelihood / classifier accuracy metrics
+#           - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics
 #       - design clustering-conscious hyperparameter tuing pipeline
 
 # TODO (in the non-immediate future):
-- 
GitLab