diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 87a23b05fbcbaacc53aedeb2b78c8e1b7ca1889a..5ba14df2b11bd52a2b8bb0a6cf7e00358014abd3 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -351,7 +351,7 @@ class MMDiscrepancyLayer(Layer):
         return z
 
 
-class Gaussian_mixture_overlap(Layer):
+class Gaussian_mixture_overlap(Layer):  # pragma: no cover
     """
     Identity layer that measures the overlap between the components of the latent Gaussian Mixture
     using a specified metric (MMD, Wasserstein, Fischer-Rao)
@@ -365,6 +365,8 @@ class Gaussian_mixture_overlap(Layer):
         super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
 
     def get_config(self):
+        """Updates Constraint metadata"""
+
         config = super().get_config().copy()
         config.update({"lat_dims": self.lat_dims})
         config.update({"n_components": self.n_components})
@@ -372,12 +374,14 @@ class Gaussian_mixture_overlap(Layer):
         config.update({"samples": self.samples})
         return config
 
-    def call(self, target, loss=False):
+    @tf.function
+    def call(self, target, **kwargs):
+        """Updates Layer's call method"""
 
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
 
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
@@ -385,7 +389,7 @@ class Gaussian_mixture_overlap(Layer):
 
         dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
 
-        ### MMD-based overlap ###
+        # MMD-based overlap #
         intercomponent_mmd = K.mean(
             tf.convert_to_tensor(
                 [
@@ -415,13 +419,15 @@ class Dead_neuron_control(Layer):
     def __init__(self, *args, **kwargs):
         super(Dead_neuron_control, self).__init__(*args, **kwargs)
 
-    def call(self, z, z_gauss, z_cat, **kwargs):
+    # noinspection PyMethodOverriding
+    def call(self, target, **kwargs):
+        """Updates Layer's call method"""
         # Adds metric that monitors dead neurons in the latent space
         self.add_metric(
-            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
+            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
         )
 
-        return z
+        return target
 
 
 class Entropy_regulariser(Layer):
@@ -429,18 +435,24 @@ class Entropy_regulariser(Layer):
     Identity layer that adds cluster weight entropy to the loss function
     """
 
-    def __init__(self, weight=1.0, *args, **kwargs):
+    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
         self.weight = weight
+        self.axis = axis
         super(Entropy_regulariser, self).__init__(*args, **kwargs)
 
     def get_config(self):
+        """Updates Constraint metadata"""
+
         config = super().get_config().copy()
         config.update({"weight": self.weight})
+        config.update({"axis": self.axis})
 
     def call(self, z, **kwargs):
+        """Updates Layer's call method"""
+
         # axis=1 increases the entropy of a cluster across instances
         # axis=0 increases the entropy of the assignment for a given instance
-        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
+        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=self.axis)
 
         # Adds metric that monitors dead neurons in the latent space
         self.add_metric(entropy, aggregation="mean", name="-weight_entropy")
diff --git a/deepof/models.py b/deepof/models.py
index 29e5b8e6335427d5505599ba606099ce0144c07f..1eb5d7968de801cbb7987b456c4b50bacd670def 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -335,6 +335,9 @@ class SEQ_2_SEQ_GMVAE:
 
         z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
 
+        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
+        z_gauss = Dead_neuron_control()(z_gauss)
+
         if self.overlap_loss:
             z_gauss = Gaussian_mixture_overlap(
                 self.ENCODING, self.number_of_components, loss=self.overlap_loss,
@@ -387,9 +390,6 @@ class SEQ_2_SEQ_GMVAE:
                 batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
             )(z)
 
-        # Identity layer controlling clustering and latent space statistics
-        z = Dead_neuron_control()(z, z_gauss, z_cat)
-
         # Define and instantiate generator
         generator = Model_D1(z)
         generator = Model_B1(generator)
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index e70115b91a09f90370645f8cb8a8848014bdcad2..b350ab0b4025b4d790f9fa660eec939ce7cf3771 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -188,7 +188,7 @@ def test_MMDiscrepancyLayer():
     y = np.random.randint(0, 2, [1500, 1])
 
     prior = tfd.Independent(
-        tfd.Normal(loc=tf.zeros(10), scale=1, ), reinterpreted_batch_ndims=1,
+        tfd.Normal(loc=tf.zeros(10), scale=1,), reinterpreted_batch_ndims=1,
     )
 
     dense_1 = tf.keras.layers.Dense(10)
@@ -197,9 +197,10 @@ def test_MMDiscrepancyLayer():
     d = dense_1(i)
     x = tfpl.DistributionLambda(
         lambda dense: tfd.Independent(
-            tfd.Normal(loc=dense, scale=1, ), reinterpreted_batch_ndims=1,
+            tfd.Normal(loc=dense, scale=1,), reinterpreted_batch_ndims=1,
         )
     )(d)
+
     x = deepof.model_utils.MMDiscrepancyLayer(
         100, prior, beta=tf.keras.backend.variable(1.0, name="kl_beta")
     )(x)
@@ -213,21 +214,21 @@ def test_MMDiscrepancyLayer():
     assert type(fit) == tf.python.keras.callbacks.History
 
 
-#
-#
-# @settings(deadline=None)
-# @given()
-# def test_gaussian_mixture_overlap():
-#     pass
-#
-#
-# @settings(deadline=None)
-# @given()
-# def test_dead_neuron_control():
-#     pass
-#
-#
-# @settings(deadline=None)
-# @given()
+def test_dead_neuron_control():
+    X = np.random.uniform(0, 10, [1500, 5])
+    y = np.random.randint(0, 2, [1500, 1])
+
+    test_model = tf.keras.Sequential()
+    test_model.add(tf.keras.layers.Dense(1))
+    test_model.add(deepof.model_utils.Dead_neuron_control())
+
+    test_model.compile(
+        loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
+    )
+
+    fit = test_model.fit(X, y, epochs=10, batch_size=100)
+    assert type(fit) == tf.python.keras.callbacks.History
+
+
 # def test_entropy_regulariser():
 #     pass