From f851531ed3d6fbf7361fd1d63c378fa387b0c78f Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 1 Jul 2020 15:41:35 +0200
Subject: [PATCH] Implemented weight saving callback in model_training.py

---
 model_training.py | 42 ++++++++++++++++++++++++++++++++++++------
 source/models.py  | 25 ++++++++++++-------------
 2 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/model_training.py b/model_training.py
index c177aa51..ff4bc3f0 100644
--- a/model_training.py
+++ b/model_training.py
@@ -288,22 +288,52 @@ input_dict_train = {
 
 input_dict_val = {
     "coords": coords2.preprocess(
-        window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+        window_size=11,
+        window_step=1,
+        scale=True,
+        random_state=42,
+        filter="gaussian",
+        sigma=110,
     ),
     "dists": distances2.preprocess(
-        window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+        window_size=11,
+        window_step=1,
+        scale=True,
+        random_state=42,
+        filter="gaussian",
+        sigma=110,
     ),
     "angles": angles2.preprocess(
-        window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+        window_size=11,
+        window_step=1,
+        scale=True,
+        random_state=42,
+        filter="gaussian",
+        sigma=110,
     ),
     "coords+dist": coords_distances2.preprocess(
-        window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+        window_size=11,
+        window_step=1,
+        scale=True,
+        random_state=42,
+        filter="gaussian",
+        sigma=110,
     ),
     "coords+angle": coords_angles2.preprocess(
-        window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+        window_size=11,
+        window_step=1,
+        scale=True,
+        random_state=42,
+        filter="gaussian",
+        sigma=110,
     ),
     "coords+dist+angle": coords_dist_angles2.preprocess(
-        window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+        window_size=11,
+        window_step=1,
+        scale=True,
+        random_state=42,
+        filter="gaussian",
+        sigma=110,
     ),
 }
 
diff --git a/source/models.py b/source/models.py
index 66651b9c..ccd41d9f 100644
--- a/source/models.py
+++ b/source/models.py
@@ -297,19 +297,6 @@ class SEQ_2_SEQ_GMVAE:
             activation=None,
         )(encoder)
 
-        # Define and control custom loss functions
-        kl_warmup_callback = False
-        if "ELBO" in self.loss:
-
-            kl_beta = K.variable(1.0, name="kl_beta")
-            kl_beta._trainable = False
-            if self.kl_warmup:
-                kl_warmup_callback = LambdaCallback(
-                    on_epoch_begin=lambda epoch, logs: K.set_value(
-                        kl_beta, K.min([epoch / self.kl_warmup, 1])
-                    )
-                )
-
         z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
         z = tfpl.DistributionLambda(
             lambda gauss: tfd.mixture.Mixture(
@@ -328,7 +315,19 @@ class SEQ_2_SEQ_GMVAE:
             activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
         )([z_cat, z_gauss])
 
+        # Define and control custom loss functions
+        kl_warmup_callback = False
         if "ELBO" in self.loss:
+
+            kl_beta = K.variable(1.0, name="kl_beta")
+            kl_beta._trainable = False
+            if self.kl_warmup:
+                kl_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        kl_beta, K.min([epoch / self.kl_warmup, 1])
+                    )
+                )
+
             z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
 
         mmd_warmup_callback = False
-- 
GitLab