diff --git a/source/model_utils.py b/source/model_utils.py index 72b354acba141747c72c07cf309d57bf6c5f3793..270c909818c8dec7649746690d11d4e92fcd4674 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -286,7 +286,7 @@ class Entropy_regulariser(Layer): def call(self, z, **kwargs): entropy = K.sum( - tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), 0)), axis=1 + tf.multiply(z, tf.where(~tf.math.is_inf(K.log(z)), K.log(z), 0)), axis=0 ) # Adds metric that monitors dead neurons in the latent space