Commit b7acdc44 authored by lucas_miranda

Implemented VAEP hypermodel in hypermodels.py

parent e2f2f4e6
@@ -118,8 +118,9 @@ class KLDivergenceLayer(Layer):
     to the final model loss.
     """
-    def __init__(self, *args, **kwargs):
+    def __init__(self, beta=1, *args, **kwargs):
         self.is_placeholder = True
+        self.beta = beta
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
 
     def call(self, inputs, **kwargs):
@@ -128,7 +129,7 @@ class KLDivergenceLayer(Layer):
         kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
-        self.add_loss(K.mean(kL_batch), inputs=inputs)
+        self.add_loss(self.beta * K.mean(kL_batch), inputs=inputs)
         return inputs
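
For context, a minimal usage sketch (not part of this commit) of how a beta-weighted KLDivergenceLayer like this one can be dropped into a VAE bottleneck with the Keras functional API; the encoder layers, dimensions, and the beta value below are illustrative assumptions, not taken from the repository:

```python
# Hypothetical usage sketch: the layer is an identity on its inputs,
# and its only effect is to register beta * KL(q(z|x) || N(0, I))
# as an additional model loss via add_loss.
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input, Lambda

latent_dim = 16

x = Input(shape=(64,))
h = Dense(128, activation="relu")(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

# Pass both distribution parameters through the regularization layer.
z_mean, z_log_var = KLDivergenceLayer(beta=0.5)([z_mean, z_log_var])

# Standard reparameterization trick: z = mu + sigma * epsilon.
z = Lambda(
    lambda p: p[0] + K.exp(0.5 * p[1]) * K.random_normal(shape=K.shape(p[0]))
)([z_mean, z_log_var])
```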
@@ -138,14 +139,15 @@ class MMDiscrepancyLayer(Layer):
     to the final model loss.
     """
-    def __init__(self, *args, **kwargs):
+    def __init__(self, beta=1, *args, **kwargs):
         self.is_placeholder = True
+        self.beta = beta
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
 
     def call(self, z, **kwargs):
         true_samples = K.random_normal(K.shape(z))
         mmd_batch = compute_mmd(true_samples, z)
-        self.add_loss(K.mean(mmd_batch), inputs=z)
+        self.add_loss(self.beta * K.mean(mmd_batch), inputs=z)
         return z
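
`compute_mmd` is defined elsewhere in the repository. For reference, a common Gaussian-kernel formulation of the maximum mean discrepancy (the quantity this layer penalizes against samples from a standard normal prior) looks roughly like the sketch below; the project's actual function may differ in kernel or bandwidth choice:

```python
# Reference sketch of a Gaussian-kernel MMD; not necessarily the
# exact compute_mmd used in this repository.
from tensorflow.keras import backend as K

def gaussian_kernel(x, y):
    # Pairwise Gaussian kernel matrix between the rows of x and y.
    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
    dim = K.shape(x)[1]
    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))
    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))
    return K.exp(-K.mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, "float32"))

def compute_mmd(x, y):
    # MMD^2(x, y) = E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)]
    return (
        K.mean(gaussian_kernel(x, x))
        + K.mean(gaussian_kernel(y, y))
        - 2 * K.mean(gaussian_kernel(x, y))
    )
```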