diff --git a/source/model_utils.py b/source/model_utils.py
index 72b354acba141747c72c07cf309d57bf6c5f3793..270c909818c8dec7649746690d11d4e92fcd4674 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -286,7 +286,7 @@ class Entropy_regulariser(Layer):
     def call(self, z, **kwargs):
 
         entropy = K.sum(
-            tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), 0)), axis=1
+            tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), 0)), axis=0
         )
 
         # Adds metric that monitors dead neurons in the latent space