Skip to content
Snippets Groups Projects
Commit 14a77ace authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Added KL and MMD warmup draft in SEQ2SEQ_VAE and SEQ2SEQ_VAEP models

parent ced17a62
No related branches found
No related tags found
No related merge requests found
...@@ -129,7 +129,7 @@ class KLDivergenceLayer(Layer): ...@@ -129,7 +129,7 @@ class KLDivergenceLayer(Layer):
kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1) kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
self.add_loss(beta * K.mean(kL_batch), inputs=inputs) self.add_loss(self.beta * K.mean(kL_batch), inputs=inputs)
return inputs return inputs
...@@ -148,6 +148,6 @@ class MMDiscrepancyLayer(Layer): ...@@ -148,6 +148,6 @@ class MMDiscrepancyLayer(Layer):
true_samples = K.random_normal(K.shape(z)) true_samples = K.random_normal(K.shape(z))
mmd_batch = compute_mmd(true_samples, z) mmd_batch = compute_mmd(true_samples, z)
self.add_loss(beta * K.mean(mmd_batch), inputs=z) self.add_loss(self.beta * K.mean(mmd_batch), inputs=z)
return z return z
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment