Commit 442ed01b authored by lucas_miranda

Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py

parent 71d92b26
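
Note on the mechanism: both regularization layers scale their penalty by a weight self.beta, and the warmup callbacks anneal that weight at the start of each epoch via K.set_value. The callback bodies are cut off in the hunks below; a minimal sketch of the pattern, assuming a linear ramp and a hypothetical warmup length kl_wu, would be:

    import tensorflow.keras.backend as K
    from tensorflow.keras.callbacks import LambdaCallback

    kl_wu = 10  # hypothetical: number of warmup epochs
    kl_beta = K.variable(1.0, name="kl_beta")
    kl_beta._trainable = False  # annealed by the callback, not by the optimizer

    # Ramp beta linearly from 0 to 1 over the first kl_wu epochs,
    # then hold it at 1 for the remainder of training.
    kl_warmup_callback = LambdaCallback(
        on_epoch_begin=lambda epoch, logs: K.set_value(
            kl_beta, min(epoch / kl_wu, 1.0)
        )
    )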
@@ -130,9 +130,10 @@ class KLDivergenceLayer(Layer):
     def call(self, inputs, **kwargs):
         mu, log_var = inputs
-        kL_batch = -0.5 * self.beta * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
-        self.add_loss(K.mean(kL_batch), inputs=inputs)
+        KL_batch = -0.5 * self.beta * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
+        self.add_loss(K.mean(KL_batch), inputs=inputs)
+        self.add_metric(KL_batch, aggregation="mean", name="kl_divergence")
         self.add_metric(self.beta, aggregation="mean", name="kl_rate")
         return inputs
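
For reference, KL_batch above is beta times the closed-form KL divergence between the diagonal-Gaussian posterior and a standard normal prior (with log_var = log sigma^2):

    \beta \, D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \sigma^{2}) \,\|\, \mathcal{N}(0, 1)\right)
    = -\frac{\beta}{2} \sum_{j} \left(1 + \log \sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)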
@@ -158,6 +159,7 @@ class MMDiscrepancyLayer(Layer):
         mmd_batch = self.beta * compute_mmd(true_samples, z)
         self.add_loss(K.mean(mmd_batch), inputs=z)
+        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
         return z
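
compute_mmd is defined elsewhere in models.py and is not shown in this diff; a typical implementation of the quantity it names, following the standard Gaussian-kernel MMD estimator (an assumption, not necessarily this repo's exact code), looks like:

    import tensorflow.keras.backend as K

    def compute_kernel(x, y):
        # Gaussian kernel between every pair of rows in x and y.
        x_size, y_size, dim = K.shape(x)[0], K.shape(y)[0], K.shape(x)[1]
        tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
        tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
        return K.exp(-K.mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, "float32"))

    def compute_mmd(x, y):
        # Plug-in MMD^2 estimate: E[k(x,x)] + E[k(y,y)] - 2 E[k(x,y)].
        return (
            K.mean(compute_kernel(x, x))
            + K.mean(compute_kernel(y, y))
            - 2 * K.mean(compute_kernel(x, y))
        )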
@@ -277,6 +277,7 @@ class SEQ_2_SEQ_VAE:
         if "ELBO" in self.loss:
             kl_beta = K.variable(1.0, name="kl_beta")
+            kl_beta._trainable = False
             if self.kl_warmup:
                 kl_warmup_callback = LambdaCallback(
@@ -293,6 +294,7 @@ class SEQ_2_SEQ_VAE:
         if "MMD" in self.loss:
             mmd_beta = K.variable(1.0, name="mmd_beta")
+            mmd_beta._trainable = False
             if self.mmd_warmup:
                 mmd_warmup_callback = LambdaCallback(
@@ -480,6 +482,7 @@ class SEQ_2_SEQ_VAEP:
         if "ELBO" in self.loss:
             kl_beta = K.variable(1.0, name="kl_beta")
+            kl_beta._trainable = False
             if self.kl_warmup:
                 kl_warmup_callback = LambdaCallback(
                     on_epoch_begin=lambda epoch, logs: K.set_value(
@@ -495,6 +498,7 @@ class SEQ_2_SEQ_VAEP:
         if "MMD" in self.loss:
             mmd_beta = K.variable(1.0, name="mmd_beta")
+            mmd_beta._trainable = False
             if self.mmd_warmup:
                 mmd_warmup_callback = LambdaCallback(
                     on_epoch_begin=lambda epoch, logs: K.set_value(
...
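
Downstream, the callbacks only take effect if they are passed to fit(); a hypothetical training call (model, x_train, and the epoch count are placeholders, and the warmup callbacks are assumed to be None when warmup is disabled) would look like:

    callbacks = []
    if kl_warmup_callback is not None:
        callbacks.append(kl_warmup_callback)
    if mmd_warmup_callback is not None:
        callbacks.append(mmd_warmup_callback)

    model.fit(x_train, x_train, epochs=50, callbacks=callbacks)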