Commit 27b68380 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for KLDivergence and MMDiscrepancy layers

parent e40304a1
......@@ -287,23 +287,31 @@ class DenseTranspose(Layer):
class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
"""
Identity transform layer that adds KL Divergence
to the final model loss.
"""
def __init__(self, *args, **kwargs):
self.is_placeholder = True
super(KLDivergenceLayer, self).__init__(*args, **kwargs)
def get_config(self):
"""Updates Constraint metadata"""
config = super().get_config().copy()
config.update(
{"is_placeholder": self.is_placeholder,}
)
config.update({"is_placeholder": self.is_placeholder})
return config
def call(self, distribution_a):
"""Updates Layer's call method"""
kl_batch = self._regularizer(distribution_a)
self.add_loss(kl_batch, inputs=[distribution_a])
self.add_metric(
kl_batch, aggregation="mean", name="kl_divergence",
)
# noinspection PyProtectedMember
self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
return distribution_a
......@@ -323,6 +331,8 @@ class MMDiscrepancyLayer(Layer):
super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
def get_config(self):
"""Updates Constraint metadata"""
config = super().get_config().copy()
config.update({"batch_size": self.batch_size})
config.update({"beta": self.beta})
......@@ -330,8 +340,10 @@ class MMDiscrepancyLayer(Layer):
return config
def call(self, z, **kwargs):
"""Updates Layer's call method"""
true_samples = self.prior.sample(self.batch_size)
mmd_batch = self.beta * compute_mmd([true_samples, z])
mmd_batch = self.beta * compute_mmd((true_samples, z))
self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd")
self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
......
......@@ -202,7 +202,7 @@ class SEQ_2_SEQ_GMVAE:
if self.prior == "standard_normal":
init_means = far_away_uniform_initialiser(
shape=[self.number_of_components, self.ENCODING], minval=0, maxval=5
shape=(self.number_of_components, self.ENCODING), minval=0, maxval=5
)
self.prior = tfd.mixture.Mixture(
......
......@@ -15,12 +15,15 @@ from hypothesis.extra.numpy import arrays
import deepof.model_utils
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.framework.ops import EagerTensor
# For coverage.py to work with @tf.function decorated functions and methods,
# graph execution is disabled when running this script with pytest
tf.config.experimental_run_functions_eagerly(True)
tfpl = tfp.layers
tfd = tfp.distributions
@settings(deadline=None)
......@@ -150,14 +153,66 @@ def test_dense_transpose():
assert type(fit) == tf.python.keras.callbacks.History
# def test_KLDivergenceLayer():
# pass
#
#
# @settings(deadline=None)
# @given()
# def test_mmdiscrepancy_layer():
# pass
def test_KLDivergenceLayer():
X = tf.random.uniform([1500, 10], 0, 10)
y = np.random.randint(0, 2, [1500, 1])
prior = tfd.Independent(
tfd.Normal(loc=tf.zeros(10), scale=1,), reinterpreted_batch_ndims=1,
)
dense_1 = tf.keras.layers.Dense(10)
i = tf.keras.layers.Input(shape=(10,))
d = dense_1(i)
x = tfpl.DistributionLambda(
lambda dense: tfd.Independent(
tfd.Normal(loc=dense, scale=1,), reinterpreted_batch_ndims=1,
)
)(d)
x = deepof.model_utils.KLDivergenceLayer(
prior, weight=tf.keras.backend.variable(1.0, name="kl_beta")
)(x)
test_model = tf.keras.Model(i, x)
test_model.compile(
loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
)
fit = test_model.fit(X, y, epochs=10, batch_size=100)
assert type(fit) == tf.python.keras.callbacks.History
def test_MMDiscrepancyLayer():
X = tf.random.uniform([1500, 10], 0, 10)
y = np.random.randint(0, 2, [1500, 1])
prior = tfd.Independent(
tfd.Normal(loc=tf.zeros(10), scale=1, ), reinterpreted_batch_ndims=1,
)
dense_1 = tf.keras.layers.Dense(10)
i = tf.keras.layers.Input(shape=(10,))
d = dense_1(i)
x = tfpl.DistributionLambda(
lambda dense: tfd.Independent(
tfd.Normal(loc=dense, scale=1, ), reinterpreted_batch_ndims=1,
)
)(d)
x = deepof.model_utils.MMDiscrepancyLayer(
100, prior, beta=tf.keras.backend.variable(1.0, name="kl_beta")
)(x)
test_model = tf.keras.Model(i, x)
test_model.compile(
loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
)
fit = test_model.fit(X, y, epochs=10, batch_size=100)
assert type(fit) == tf.python.keras.callbacks.History
#
#
# @settings(deadline=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment