From 442ed01be4ed015f4437621c2227dd542e0adff8 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 5 Jun 2020 12:41:50 +0200
Subject: [PATCH] Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py

---
 source/model_utils.py | 6 ++++--
 source/models.py      | 4 ++++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/source/model_utils.py b/source/model_utils.py
index 0abc4431..75a31206 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -130,9 +130,10 @@ class KLDivergenceLayer(Layer):
 
     def call(self, inputs, **kwargs):
         mu, log_var = inputs
-        kL_batch = -0.5 * self.beta * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
+        KL_batch = -0.5 * self.beta * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
 
-        self.add_loss(K.mean(kL_batch), inputs=inputs)
+        self.add_loss(K.mean(KL_batch), inputs=inputs)
+        self.add_metric(KL_batch, aggregation="mean", name="kl_divergence")
         self.add_metric(self.beta, aggregation="mean", name="kl_rate")
 
         return inputs
@@ -158,6 +159,7 @@ class MMDiscrepancyLayer(Layer):
         mmd_batch = self.beta * compute_mmd(true_samples, z)
 
         self.add_loss(K.mean(mmd_batch), inputs=z)
+        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
 
         return z
diff --git a/source/models.py b/source/models.py
index 24d50ee3..f978d651 100644
--- a/source/models.py
+++ b/source/models.py
@@ -277,6 +277,7 @@ class SEQ_2_SEQ_VAE:
         if "ELBO" in self.loss:
 
             kl_beta = K.variable(1.0, name="kl_beta")
+            kl_beta._trainable = False
             if self.kl_warmup:
 
                 kl_warmup_callback = LambdaCallback(
@@ -293,6 +294,7 @@ class SEQ_2_SEQ_VAE:
         if "MMD" in self.loss:
 
             mmd_beta = K.variable(1.0, name="mmd_beta")
+            mmd_beta._trainable = False
             if self.mmd_warmup:
 
                 mmd_warmup_callback = LambdaCallback(
@@ -480,6 +482,7 @@ class SEQ_2_SEQ_VAEP:
         if "ELBO" in self.loss:
 
             kl_beta = K.variable(1.0, name="kl_beta")
+            kl_beta._trainable = False
             if self.kl_warmup:
                 kl_warmup_callback = LambdaCallback(
                     on_epoch_begin=lambda epoch, logs: K.set_value(
@@ -495,6 +498,7 @@ class SEQ_2_SEQ_VAEP:
         if "MMD" in self.loss:
 
             mmd_beta = K.variable(1.0, name="mmd_beta")
+            mmd_beta._trainable = False
             if self.mmd_warmup:
                 mmd_warmup_callback = LambdaCallback(
                     on_epoch_begin=lambda epoch, logs: K.set_value(
-- 
GitLab