From f851531ed3d6fbf7361fd1d63c378fa387b0c78f Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 1 Jul 2020 15:41:35 +0200
Subject: [PATCH] Implemented weight saving callback in model_training.py
---
model_training.py | 42 ++++++++++++++++++++++++++++++++++++------
source/models.py | 25 ++++++++++++-------------
2 files changed, 48 insertions(+), 19 deletions(-)
diff --git a/model_training.py b/model_training.py
index c177aa51..ff4bc3f0 100644
--- a/model_training.py
+++ b/model_training.py
@@ -288,22 +288,52 @@ input_dict_train = {
input_dict_val = {
"coords": coords2.preprocess(
- window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+ window_size=11,
+ window_step=1,
+ scale=True,
+ random_state=42,
+ filter="gaussian",
+ sigma=110,
),
"dists": distances2.preprocess(
- window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+ window_size=11,
+ window_step=1,
+ scale=True,
+ random_state=42,
+ filter="gaussian",
+ sigma=110,
),
"angles": angles2.preprocess(
- window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+ window_size=11,
+ window_step=1,
+ scale=True,
+ random_state=42,
+ filter="gaussian",
+ sigma=110,
),
"coords+dist": coords_distances2.preprocess(
- window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+ window_size=11,
+ window_step=1,
+ scale=True,
+ random_state=42,
+ filter="gaussian",
+ sigma=110,
),
"coords+angle": coords_angles2.preprocess(
- window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+ window_size=11,
+ window_step=1,
+ scale=True,
+ random_state=42,
+ filter="gaussian",
+ sigma=110,
),
"coords+dist+angle": coords_dist_angles2.preprocess(
- window_size=11, window_step=1, scale=True, random_state=42, filter="gauss"
+ window_size=11,
+ window_step=1,
+ scale=True,
+ random_state=42,
+ filter="gaussian",
+ sigma=110,
),
}
diff --git a/source/models.py b/source/models.py
index 66651b9c..ccd41d9f 100644
--- a/source/models.py
+++ b/source/models.py
@@ -297,19 +297,6 @@ class SEQ_2_SEQ_GMVAE:
activation=None,
)(encoder)
- # Define and control custom loss functions
- kl_warmup_callback = False
- if "ELBO" in self.loss:
-
- kl_beta = K.variable(1.0, name="kl_beta")
- kl_beta._trainable = False
- if self.kl_warmup:
- kl_warmup_callback = LambdaCallback(
- on_epoch_begin=lambda epoch, logs: K.set_value(
- kl_beta, K.min([epoch / self.kl_warmup, 1])
- )
- )
-
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z = tfpl.DistributionLambda(
lambda gauss: tfd.mixture.Mixture(
@@ -328,7 +315,19 @@ class SEQ_2_SEQ_GMVAE:
activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
)([z_cat, z_gauss])
+ # Define and control custom loss functions
+ kl_warmup_callback = False
if "ELBO" in self.loss:
+
+ kl_beta = K.variable(1.0, name="kl_beta")
+ kl_beta._trainable = False
+ if self.kl_warmup:
+ kl_warmup_callback = LambdaCallback(
+ on_epoch_begin=lambda epoch, logs: K.set_value(
+ kl_beta, K.min([epoch / self.kl_warmup, 1])
+ )
+ )
+
z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
mmd_warmup_callback = False
--
GitLab