From f851531ed3d6fbf7361fd1d63c378fa387b0c78f Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Wed, 1 Jul 2020 15:41:35 +0200 Subject: [PATCH] Implemented weight saving callback in model_training.py --- model_training.py | 42 ++++++++++++++++++++++++++++++++++++------ source/models.py | 25 ++++++++++++------------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/model_training.py b/model_training.py index c177aa51..ff4bc3f0 100644 --- a/model_training.py +++ b/model_training.py @@ -288,22 +288,52 @@ input_dict_train = { input_dict_val = { "coords": coords2.preprocess( - window_size=11, window_step=1, scale=True, random_state=42, filter="gauss" + window_size=11, + window_step=1, + scale=True, + random_state=42, + filter="gaussian", + sigma=110, ), "dists": distances2.preprocess( - window_size=11, window_step=1, scale=True, random_state=42, filter="gauss" + window_size=11, + window_step=1, + scale=True, + random_state=42, + filter="gaussian", + sigma=110, ), "angles": angles2.preprocess( - window_size=11, window_step=1, scale=True, random_state=42, filter="gauss" + window_size=11, + window_step=1, + scale=True, + random_state=42, + filter="gaussian", + sigma=110, ), "coords+dist": coords_distances2.preprocess( - window_size=11, window_step=1, scale=True, random_state=42, filter="gauss" + window_size=11, + window_step=1, + scale=True, + random_state=42, + filter="gaussian", + sigma=110, ), "coords+angle": coords_angles2.preprocess( - window_size=11, window_step=1, scale=True, random_state=42, filter="gauss" + window_size=11, + window_step=1, + scale=True, + random_state=42, + filter="gaussian", + sigma=110, ), "coords+dist+angle": coords_dist_angles2.preprocess( - window_size=11, window_step=1, scale=True, random_state=42, filter="gauss" + window_size=11, + window_step=1, + scale=True, + random_state=42, + filter="gaussian", + sigma=110, ), } diff --git a/source/models.py b/source/models.py index 66651b9c..ccd41d9f 100644 --- a/source/models.py +++ b/source/models.py @@ -297,19 +297,6 @@ class SEQ_2_SEQ_GMVAE: activation=None, )(encoder) - # Define and control custom loss functions - kl_warmup_callback = False - if "ELBO" in self.loss: - - kl_beta = K.variable(1.0, name="kl_beta") - kl_beta._trainable = False - if self.kl_warmup: - kl_warmup_callback = LambdaCallback( - on_epoch_begin=lambda epoch, logs: K.set_value( - kl_beta, K.min([epoch / self.kl_warmup, 1]) - ) - ) - z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss) z = tfpl.DistributionLambda( lambda gauss: tfd.mixture.Mixture( @@ -328,7 +315,19 @@ class SEQ_2_SEQ_GMVAE: activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0), )([z_cat, z_gauss]) + # Define and control custom loss functions + kl_warmup_callback = False if "ELBO" in self.loss: + + kl_beta = K.variable(1.0, name="kl_beta") + kl_beta._trainable = False + if self.kl_warmup: + kl_warmup_callback = LambdaCallback( + on_epoch_begin=lambda epoch, logs: K.set_value( + kl_beta, K.min([epoch / self.kl_warmup, 1]) + ) + ) + z = KLDivergenceLayer(self.prior, weight=kl_beta)(z) mmd_warmup_callback = False -- GitLab