diff --git a/source/model_utils.py b/source/model_utils.py index 0abc4431831c215d8c88db3006cd99e21e9a5816..75a312068b04814613f935086814df6dd9da3b4b 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -130,9 +130,10 @@ class KLDivergenceLayer(Layer): def call(self, inputs, **kwargs): mu, log_var = inputs - kL_batch = -0.5 * self.beta * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1) + KL_batch = -0.5 * self.beta * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1) - self.add_loss(K.mean(kL_batch), inputs=inputs) + self.add_loss(K.mean(KL_batch), inputs=inputs) + self.add_metric(KL_batch, aggregation="mean", name="kl_divergence") self.add_metric(self.beta, aggregation="mean", name="kl_rate") return inputs @@ -158,6 +159,7 @@ class MMDiscrepancyLayer(Layer): mmd_batch = self.beta * compute_mmd(true_samples, z) self.add_loss(K.mean(mmd_batch), inputs=z) + self.add_metric(mmd_batch, aggregation="mean", name="mmd") self.add_metric(self.beta, aggregation="mean", name="mmd_rate") return z diff --git a/source/models.py b/source/models.py index 24d50ee3a3e05cacb47e15d7d6fac138e78c1420..f978d651e286b23ae3ad607eb4ede1a8c5466fde 100644 --- a/source/models.py +++ b/source/models.py @@ -277,6 +277,7 @@ class SEQ_2_SEQ_VAE: if "ELBO" in self.loss: kl_beta = K.variable(1.0, name="kl_beta") + kl_beta._trainable = False if self.kl_warmup: kl_warmup_callback = LambdaCallback( @@ -293,6 +294,7 @@ class SEQ_2_SEQ_VAE: if "MMD" in self.loss: mmd_beta = K.variable(1.0, name="mmd_beta") + mmd_beta._trainable = False if self.mmd_warmup: mmd_warmup_callback = LambdaCallback( @@ -480,6 +482,7 @@ class SEQ_2_SEQ_VAEP: if "ELBO" in self.loss: kl_beta = K.variable(1.0, name="kl_beta") + kl_beta._trainable = False if self.kl_warmup: kl_warmup_callback = LambdaCallback( on_epoch_begin=lambda epoch, logs: K.set_value( @@ -495,6 +498,7 @@ class SEQ_2_SEQ_VAEP: if "MMD" in self.loss: mmd_beta = K.variable(1.0, name="mmd_beta") + mmd_beta._trainable = False if self.mmd_warmup: mmd_warmup_callback = LambdaCallback( on_epoch_begin=lambda epoch, logs: K.set_value(