diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 036bea1a857671d68e2ce00996b998dcabc4c344..87a23b05fbcbaacc53aedeb2b78c8e1b7ca1889a 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -287,23 +287,31 @@ class DenseTranspose(Layer):
 
 
 class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
+    """
+    Identity transform layer that adds KL Divergence
+    to the final model loss.
+    """
+
     def __init__(self, *args, **kwargs):
         self.is_placeholder = True
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
 
     def get_config(self):
+        """Updates Constraint metadata"""
+
         config = super().get_config().copy()
-        config.update(
-            {"is_placeholder": self.is_placeholder,}
-        )
+        config.update({"is_placeholder": self.is_placeholder})
         return config
 
     def call(self, distribution_a):
+        """Updates Layer's call method"""
+
         kl_batch = self._regularizer(distribution_a)
         self.add_loss(kl_batch, inputs=[distribution_a])
         self.add_metric(
             kl_batch, aggregation="mean", name="kl_divergence",
         )
+        # noinspection PyProtectedMember
         self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
 
         return distribution_a
@@ -323,6 +331,8 @@ class MMDiscrepancyLayer(Layer):
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
 
     def get_config(self):
+        """Updates Constraint metadata"""
+
         config = super().get_config().copy()
         config.update({"batch_size": self.batch_size})
         config.update({"beta": self.beta})
@@ -330,8 +340,10 @@ class MMDiscrepancyLayer(Layer):
         return config
 
     def call(self, z, **kwargs):
+        """Updates Layer's call method"""
+
         true_samples = self.prior.sample(self.batch_size)
-        mmd_batch = self.beta * compute_mmd([true_samples, z])
+        mmd_batch = self.beta * compute_mmd((true_samples, z))
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
diff --git a/deepof/models.py b/deepof/models.py
index 6ac1cc6efc2e15c4e0e540492fe6cd8a1b2b98f8..29e5b8e6335427d5505599ba606099ce0144c07f 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -202,7 +202,7 @@ class SEQ_2_SEQ_GMVAE:
 
         if self.prior == "standard_normal":
             init_means = far_away_uniform_initialiser(
-                shape=[self.number_of_components, self.ENCODING], minval=0, maxval=5
+                shape=(self.number_of_components, self.ENCODING), minval=0, maxval=5
             )
 
             self.prior = tfd.mixture.Mixture(
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index 512a64523f5da15a3cedf24889d636351b31fcc4..e70115b91a09f90370645f8cb8a8848014bdcad2 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -15,12 +15,15 @@ from hypothesis.extra.numpy import arrays
 import deepof.model_utils
 import numpy as np
 import tensorflow as tf
+import tensorflow_probability as tfp
 from tensorflow.python.framework.ops import EagerTensor
 
 # For coverage.py to work with @tf.function decorated functions and methods,
 # graph execution is disabled when running this script with pytest
 tf.config.experimental_run_functions_eagerly(True)
 
+tfpl = tfp.layers
+tfd = tfp.distributions
 
 
 @settings(deadline=None)
@@ -150,14 +153,66 @@ def test_dense_transpose():
     assert type(fit) == tf.python.keras.callbacks.History
 
 
-# def test_KLDivergenceLayer():
-#     pass
-#
-#
-# @settings(deadline=None)
-# @given()
-# def test_mmdiscrepancy_layer():
-#     pass
+def test_KLDivergenceLayer():
+    X = tf.random.uniform([1500, 10], 0, 10)
+    y = np.random.randint(0, 2, [1500, 1])
+
+    prior = tfd.Independent(
+        tfd.Normal(loc=tf.zeros(10), scale=1,), reinterpreted_batch_ndims=1,
+    )
+
+    dense_1 = tf.keras.layers.Dense(10)
+
+    i = tf.keras.layers.Input(shape=(10,))
+    d = dense_1(i)
+    x = tfpl.DistributionLambda(
+        lambda dense: tfd.Independent(
+            tfd.Normal(loc=dense, scale=1,), reinterpreted_batch_ndims=1,
+        )
+    )(d)
+    x = deepof.model_utils.KLDivergenceLayer(
+        prior, weight=tf.keras.backend.variable(1.0, name="kl_beta")
+    )(x)
+    test_model = tf.keras.Model(i, x)
+
+    test_model.compile(
+        loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
+    )
+
+    fit = test_model.fit(X, y, epochs=10, batch_size=100)
+    assert type(fit) == tf.python.keras.callbacks.History
+
+
+def test_MMDiscrepancyLayer():
+    X = tf.random.uniform([1500, 10], 0, 10)
+    y = np.random.randint(0, 2, [1500, 1])
+
+    prior = tfd.Independent(
+        tfd.Normal(loc=tf.zeros(10), scale=1,), reinterpreted_batch_ndims=1,
+    )
+
+    dense_1 = tf.keras.layers.Dense(10)
+
+    i = tf.keras.layers.Input(shape=(10,))
+    d = dense_1(i)
+    x = tfpl.DistributionLambda(
+        lambda dense: tfd.Independent(
+            tfd.Normal(loc=dense, scale=1,), reinterpreted_batch_ndims=1,
+        )
+    )(d)
+    x = deepof.model_utils.MMDiscrepancyLayer(
+        100, prior, beta=tf.keras.backend.variable(1.0, name="kl_beta")
+    )(x)
+    test_model = tf.keras.Model(i, x)
+
+    test_model.compile(
+        loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
+    )
+
+    fit = test_model.fit(X, y, epochs=10, batch_size=100)
+    assert type(fit) == tf.python.keras.callbacks.History
+
+
 #
 #
 # @settings(deadline=None)
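
For context, both regularisation layers covered by the new tests are identity transforms: they pass the latent code through unchanged and only contribute loss terms and metrics. The sketch below condenses the pattern the tests exercise into one model, attaching the two layers in sequence for compactness (the tests use them separately). It assumes `deepof`, `tensorflow` and `tensorflow_probability` are installed; the encoding size, batch size and the `mmd_beta` variable name are illustrative choices, not values prescribed by the patch.

```python
import tensorflow as tf
import tensorflow_probability as tfp

import deepof.model_utils

tfd = tfp.distributions
tfpl = tfp.layers

ENCODING = 10  # illustrative latent dimensionality
BATCH_SIZE = 100  # must match the batch size later passed to fit()

# Standard normal prior over the latent space
prior = tfd.Independent(
    tfd.Normal(loc=tf.zeros(ENCODING), scale=1.0), reinterpreted_batch_ndims=1
)

inputs = tf.keras.layers.Input(shape=(ENCODING,))
hidden = tf.keras.layers.Dense(ENCODING)(inputs)

# Posterior over the latent code, parameterised by the dense layer above
z = tfpl.DistributionLambda(
    lambda loc: tfd.Independent(
        tfd.Normal(loc=loc, scale=1.0), reinterpreted_batch_ndims=1
    )
)(hidden)

# Identity layers: z is returned unchanged, but the KL divergence and MMD terms
# are added to the model loss, and kl_divergence / kl_rate and mmd / mmd_rate
# are exposed as training metrics
z = deepof.model_utils.KLDivergenceLayer(
    prior, weight=tf.keras.backend.variable(1.0, name="kl_beta")
)(z)
z = deepof.model_utils.MMDiscrepancyLayer(
    BATCH_SIZE, prior, beta=tf.keras.backend.variable(1.0, name="mmd_beta")
)(z)

model = tf.keras.Model(inputs, z)
model.compile(
    loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD()
)
```

Because both weights are `tf.keras.backend.variable` instances and surface as the `kl_rate` and `mmd_rate` metrics, they can be updated from a Keras callback during training, e.g. to anneal the strength of the regularisation terms.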