Skip to content
Snippets Groups Projects
Commit 27b68380 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Added tests for KLDivergence and MMDiscrepancy layers

parent e40304a1
No related branches found
No related tags found
No related merge requests found
......@@ -287,23 +287,31 @@ class DenseTranspose(Layer):
class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
"""
Identity transform layer that adds KL Divergence
to the final model loss.
"""
def __init__(self, *args, **kwargs):
self.is_placeholder = True
super(KLDivergenceLayer, self).__init__(*args, **kwargs)
def get_config(self):
"""Updates Constraint metadata"""
config = super().get_config().copy()
config.update(
{"is_placeholder": self.is_placeholder,}
)
config.update({"is_placeholder": self.is_placeholder})
return config
def call(self, distribution_a):
"""Updates Layer's call method"""
kl_batch = self._regularizer(distribution_a)
self.add_loss(kl_batch, inputs=[distribution_a])
self.add_metric(
kl_batch, aggregation="mean", name="kl_divergence",
)
# noinspection PyProtectedMember
self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
return distribution_a
......@@ -323,6 +331,8 @@ class MMDiscrepancyLayer(Layer):
super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
def get_config(self):
"""Updates Constraint metadata"""
config = super().get_config().copy()
config.update({"batch_size": self.batch_size})
config.update({"beta": self.beta})
......@@ -330,8 +340,10 @@ class MMDiscrepancyLayer(Layer):
return config
def call(self, z, **kwargs):
"""Updates Layer's call method"""
true_samples = self.prior.sample(self.batch_size)
mmd_batch = self.beta * compute_mmd([true_samples, z])
mmd_batch = self.beta * compute_mmd((true_samples, z))
self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd")
self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
......
......@@ -202,7 +202,7 @@ class SEQ_2_SEQ_GMVAE:
if self.prior == "standard_normal":
init_means = far_away_uniform_initialiser(
shape=[self.number_of_components, self.ENCODING], minval=0, maxval=5
shape=(self.number_of_components, self.ENCODING), minval=0, maxval=5
)
self.prior = tfd.mixture.Mixture(
......
......@@ -15,12 +15,15 @@ from hypothesis.extra.numpy import arrays
import deepof.model_utils
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.framework.ops import EagerTensor
# For coverage.py to work with @tf.function decorated functions and methods,
# graph execution is disabled when running this script with pytest
tf.config.experimental_run_functions_eagerly(True)
tfpl = tfp.layers
tfd = tfp.distributions
@settings(deadline=None)
......@@ -150,14 +153,66 @@ def test_dense_transpose():
assert type(fit) == tf.python.keras.callbacks.History
# def test_KLDivergenceLayer():
# pass
#
#
# @settings(deadline=None)
# @given()
# def test_mmdiscrepancy_layer():
# pass
def test_KLDivergenceLayer():
X = tf.random.uniform([1500, 10], 0, 10)
y = np.random.randint(0, 2, [1500, 1])
prior = tfd.Independent(
tfd.Normal(loc=tf.zeros(10), scale=1,), reinterpreted_batch_ndims=1,
)
dense_1 = tf.keras.layers.Dense(10)
i = tf.keras.layers.Input(shape=(10,))
d = dense_1(i)
x = tfpl.DistributionLambda(
lambda dense: tfd.Independent(
tfd.Normal(loc=dense, scale=1,), reinterpreted_batch_ndims=1,
)
)(d)
x = deepof.model_utils.KLDivergenceLayer(
prior, weight=tf.keras.backend.variable(1.0, name="kl_beta")
)(x)
test_model = tf.keras.Model(i, x)
test_model.compile(
loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
)
fit = test_model.fit(X, y, epochs=10, batch_size=100)
assert type(fit) == tf.python.keras.callbacks.History
def test_MMDiscrepancyLayer():
X = tf.random.uniform([1500, 10], 0, 10)
y = np.random.randint(0, 2, [1500, 1])
prior = tfd.Independent(
tfd.Normal(loc=tf.zeros(10), scale=1, ), reinterpreted_batch_ndims=1,
)
dense_1 = tf.keras.layers.Dense(10)
i = tf.keras.layers.Input(shape=(10,))
d = dense_1(i)
x = tfpl.DistributionLambda(
lambda dense: tfd.Independent(
tfd.Normal(loc=dense, scale=1, ), reinterpreted_batch_ndims=1,
)
)(d)
x = deepof.model_utils.MMDiscrepancyLayer(
100, prior, beta=tf.keras.backend.variable(1.0, name="kl_beta")
)(x)
test_model = tf.keras.Model(i, x)
test_model.compile(
loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
)
fit = test_model.fit(X, y, epochs=10, batch_size=100)
assert type(fit) == tf.python.keras.callbacks.History
#
#
# @settings(deadline=None)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment