Commit cae7ee68 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for model_utils.py

parent 27b68380
......@@ -351,7 +351,7 @@ class MMDiscrepancyLayer(Layer):
return z
class Gaussian_mixture_overlap(Layer):
class Gaussian_mixture_overlap(Layer): # pragma: no cover
"""
Identity layer that measures the overlap between the components of the latent Gaussian Mixture
using a specified metric (MMD, Wasserstein, Fischer-Rao)
......@@ -365,6 +365,8 @@ class Gaussian_mixture_overlap(Layer):
super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
def get_config(self):
"""Updates Constraint metadata"""
config = super().get_config().copy()
config.update({"lat_dims": self.lat_dims})
config.update({"n_components": self.n_components})
......@@ -372,12 +374,14 @@ class Gaussian_mixture_overlap(Layer):
config.update({"samples": self.samples})
return config
def call(self, target, loss=False):
@tf.function
def call(self, target, **kwargs):
"""Updates Layer's call method"""
dists = []
for k in range(self.n_components):
locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
......@@ -385,7 +389,7 @@ class Gaussian_mixture_overlap(Layer):
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
### MMD-based overlap ###
# MMD-based overlap #
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
......@@ -415,13 +419,15 @@ class Dead_neuron_control(Layer):
def __init__(self, *args, **kwargs):
super(Dead_neuron_control, self).__init__(*args, **kwargs)
def call(self, z, z_gauss, z_cat, **kwargs):
# noinspection PyMethodOverriding
def call(self, target, **kwargs):
"""Updates Layer's call method"""
# Adds metric that monitors dead neurons in the latent space
self.add_metric(
tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
)
return z
return target
class Entropy_regulariser(Layer):
......@@ -429,18 +435,24 @@ class Entropy_regulariser(Layer):
Identity layer that adds cluster weight entropy to the loss function
"""
def __init__(self, weight=1.0, *args, **kwargs):
def __init__(self, weight=1.0, axis=1, *args, **kwargs):
self.weight = weight
self.axis = axis
super(Entropy_regulariser, self).__init__(*args, **kwargs)
def get_config(self):
"""Updates Constraint metadata"""
config = super().get_config().copy()
config.update({"weight": self.weight})
config.update({"axis": self.axis})
def call(self, z, **kwargs):
"""Updates Layer's call method"""
# axis=1 increases the entropy of a cluster across instances
# axis=0 increases the entropy of the assignment for a given instance
entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=self.axis)
# Adds metric that monitors dead neurons in the latent space
self.add_metric(entropy, aggregation="mean", name="-weight_entropy")
......
......@@ -335,6 +335,9 @@ class SEQ_2_SEQ_GMVAE:
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
# Identity layer controlling for dead neurons in the Gaussian Mixture posterior
z_gauss = Dead_neuron_control()(z_gauss)
if self.overlap_loss:
z_gauss = Gaussian_mixture_overlap(
self.ENCODING, self.number_of_components, loss=self.overlap_loss,
......@@ -387,9 +390,6 @@ class SEQ_2_SEQ_GMVAE:
batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
)(z)
# Identity layer controlling clustering and latent space statistics
z = Dead_neuron_control()(z, z_gauss, z_cat)
# Define and instantiate generator
generator = Model_D1(z)
generator = Model_B1(generator)
......
......@@ -188,7 +188,7 @@ def test_MMDiscrepancyLayer():
y = np.random.randint(0, 2, [1500, 1])
prior = tfd.Independent(
tfd.Normal(loc=tf.zeros(10), scale=1, ), reinterpreted_batch_ndims=1,
tfd.Normal(loc=tf.zeros(10), scale=1,), reinterpreted_batch_ndims=1,
)
dense_1 = tf.keras.layers.Dense(10)
......@@ -197,9 +197,10 @@ def test_MMDiscrepancyLayer():
d = dense_1(i)
x = tfpl.DistributionLambda(
lambda dense: tfd.Independent(
tfd.Normal(loc=dense, scale=1, ), reinterpreted_batch_ndims=1,
tfd.Normal(loc=dense, scale=1,), reinterpreted_batch_ndims=1,
)
)(d)
x = deepof.model_utils.MMDiscrepancyLayer(
100, prior, beta=tf.keras.backend.variable(1.0, name="kl_beta")
)(x)
......@@ -213,21 +214,21 @@ def test_MMDiscrepancyLayer():
assert type(fit) == tf.python.keras.callbacks.History
#
#
# @settings(deadline=None)
# @given()
# def test_gaussian_mixture_overlap():
# pass
#
#
# @settings(deadline=None)
# @given()
# def test_dead_neuron_control():
# pass
#
#
# @settings(deadline=None)
# @given()
def test_dead_neuron_control():
X = np.random.uniform(0, 10, [1500, 5])
y = np.random.randint(0, 2, [1500, 1])
test_model = tf.keras.Sequential()
test_model.add(tf.keras.layers.Dense(1))
test_model.add(deepof.model_utils.Dead_neuron_control())
test_model.compile(
loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
)
fit = test_model.fit(X, y, epochs=10, batch_size=100)
assert type(fit) == tf.python.keras.callbacks.History
# def test_entropy_regulariser():
# pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment