Skip to content
Snippets Groups Projects
Commit b7acdc44 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Implemented VAEP hypermodel in hypermodels.py

parent e2f2f4e6
No related branches found
No related tags found
No related merge requests found
......@@ -118,8 +118,9 @@ class KLDivergenceLayer(Layer):
to the final model loss.
"""
def __init__(self, *args, **kwargs):
def __init__(self, beta=1, *args, **kwargs):
self.is_placeholder = True
self.beta = beta
super(KLDivergenceLayer, self).__init__(*args, **kwargs)
def call(self, inputs, **kwargs):
......@@ -128,7 +129,7 @@ class KLDivergenceLayer(Layer):
kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
self.add_loss(K.mean(kL_batch), inputs=inputs)
self.add_loss(beta * K.mean(kL_batch), inputs=inputs)
return inputs
......@@ -138,14 +139,15 @@ class MMDiscrepancyLayer(Layer):
to the final model loss.
"""
def __init__(self, *args, **kwargs):
def __init__(self, beta=1, *args, **kwargs):
self.is_placeholder = True
self.beta = beta
super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
def call(self, z, **kwargs):
true_samples = K.random_normal(K.shape(z))
mmd_batch = compute_mmd(true_samples, z)
self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_loss(beta * K.mean(mmd_batch), inputs=z)
return z
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment