diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index bd484f1ca1b80bab767d9c4be4f35c1a739eeaf5..fa3acc9eec484b31867a5c5abdf41e42023375a4 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -138,7 +138,7 @@ def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
 
 
 @tf.function
-def compute_mmd(tensors: Tuple[Any, Any]) -> tf.Tensor:
+def compute_mmd(tensors: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
     """
 
     Computes the MMD between the two specified vectors using a gaussian kernel.
diff --git a/deepof/models.py b/deepof/models.py
index 896fa3f18daff2f2722370de105c9b812b8566af..18dfb5584a7b34e19a62f7c817c43a6d1d6ca2ed 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -594,7 +594,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
+                            scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
                         ),
                         reinterpreted_batch_ndims=1,
                     )
@@ -744,4 +744,4 @@ class SEQ_2_SEQ_GMVAE:
 
 # TODO:
 #       - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
-#       - Investigate full covariance matrix approximation for the latent space! :)
\ No newline at end of file
+#       - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
diff --git a/deepof/train_model.py b/deepof/train_model.py
index 61702e68ec2091a920eeaa80aecc95c92b193e11..67d3519d47fd9349c400ff51ee5197fae241c4d6 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -349,7 +349,7 @@ if not tune:
             predictor=predictor,
             loss=loss,
             logparam=logparam,
-            outpath=output_path
+            outpath=output_path,
         )
 
         logparams = [
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index b1d08c8529e67411bc4ffa1ef618b1e327b4b718..7d98cf4b4ca9ccc412f763116ef68a8361473043 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -61,6 +61,7 @@ def test_compute_mmd(tensor):
     assert null_kernel == 0
 
 
+# noinspection PyUnresolvedReferences
 def test_one_cycle_scheduler():
     cycle1 = deepof.model_utils.one_cycle_scheduler(
         iterations=5, max_rate=1.0, start_rate=0.1, last_iterations=2, last_rate=0.3
@@ -89,6 +90,7 @@ def test_one_cycle_scheduler():
     assert onecycle.history["lr"][4] > onecycle.history["lr"][-1]
 
 
+# noinspection PyUnresolvedReferences
 def test_uncorrelated_features_constraint():
     X = np.random.uniform(0, 10, [1500, 5])
     y = np.random.randint(0, 2, [1500, 1])
@@ -120,6 +122,7 @@ def test_uncorrelated_features_constraint():
     assert correlations[0] > correlations[1]
 
 
+# noinspection PyUnresolvedReferences
 def test_MCDropout():
     X = np.random.uniform(0, 10, [1500, 5])
     y = np.random.randint(0, 2, [1500, 1])
@@ -137,6 +140,7 @@ def test_MCDropout():
     assert type(fit) == tf.python.keras.callbacks.History
 
 
+# noinspection PyUnresolvedReferences
 def test_dense_transpose():
     X = np.random.uniform(0, 10, [1500, 10])
     y = np.random.randint(0, 2, [1500, 1])
@@ -157,21 +161,22 @@ def test_dense_transpose():
     assert type(fit) == tf.python.keras.callbacks.History
 
 
+# noinspection PyCallingNonCallable,PyUnresolvedReferences
 def test_KLDivergenceLayer():
-    X = tf.random.uniform([1500, 10], 0, 10)
-    y = np.random.randint(0, 2, [1500, 1])
+    X = tf.random.uniform([10, 2], 0, 10)
+    y = np.random.randint(0, 1, [10, 1])
 
     prior = tfd.Independent(
         tfd.Normal(
-            loc=tf.zeros(10),
+            loc=tf.zeros(2),
             scale=1,
         ),
         reinterpreted_batch_ndims=1,
     )
 
-    dense_1 = tf.keras.layers.Dense(10)
+    dense_1 = tf.keras.layers.Dense(2)
 
-    i = tf.keras.layers.Input(shape=(10,))
+    i = tf.keras.layers.Input(shape=(2,))
     d = dense_1(i)
     x = tfpl.DistributionLambda(
         lambda dense: tfd.Independent(
@@ -182,20 +187,25 @@ def test_KLDivergenceLayer():
             reinterpreted_batch_ndims=1,
         )
     )(d)
-    x = deepof.model_utils.KLDivergenceLayer(
-        prior, weight=tf.keras.backend.variable(1.0, name="kl_beta")
+    kl_canon = tfpl.KLDivergenceAddLoss(
+        prior, weight=1.0
+    )(x)
-    test_model = tf.keras.Model(i, x)
+    kl_deepof = deepof.model_utils.KLDivergenceLayer(
+        prior, weight=1.0
+    )(x)
+    test_model = tf.keras.Model(i, [kl_canon, kl_deepof])
 
     test_model.compile(
         loss=tf.keras.losses.binary_crossentropy,
         optimizer=tf.keras.optimizers.SGD(),
     )
 
-    fit = test_model.fit(X, y, epochs=10, batch_size=100)
-    assert type(fit) == tf.python.keras.callbacks.History
+    fit = test_model.fit(X, [y, y], epochs=1, batch_size=100)
+    assert isinstance(fit, tf.python.keras.callbacks.History)
+    assert test_model.losses[0] == test_model.losses[1]
 
 
+# noinspection PyUnresolvedReferences
 def test_MMDiscrepancyLayer():
     X = tf.random.uniform([1500, 10], 0, 10)
     y = np.random.randint(0, 2, [1500, 1])
@@ -233,9 +243,10 @@ def test_MMDiscrepancyLayer():
     )
 
     fit = test_model.fit(X, y, epochs=10, batch_size=100)
-    assert type(fit) == tf.python.keras.callbacks.History
+    assert isinstance(fit, tf.python.keras.callbacks.History)
 
 
+# noinspection PyUnresolvedReferences
 def test_dead_neuron_control():
     X = np.random.uniform(0, 10, [1500, 5])
     y = np.random.randint(0, 2, [1500, 1])
@@ -250,9 +261,10 @@ def test_dead_neuron_control():
     )
 
     fit = test_model.fit(X, y, epochs=10, batch_size=100)
-    assert type(fit) == tf.python.keras.callbacks.History
+    assert isinstance(fit, tf.python.keras.callbacks.History)
 
 
+# noinspection PyUnresolvedReferences
 def test_entropy_regulariser():
     X = np.random.uniform(0, 10, [1500, 5])
     y = np.random.randint(0, 2, [1500, 1])