diff --git a/source/model_utils.py b/source/model_utils.py
index 2162f46af653ade60ad5786551e80ca7e2e3f7eb..57a58458f8964360e9250aad62dfc9b44325fd7f 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -125,13 +125,13 @@ class MultivariateNormalDiag(tfpl.DistributionLambda):
         )
 
     @staticmethod
-    def new(params, event_size, validate_args=False, name=None):
+    def new(params, event_size, epsilon=1e-8, validate_args=False, name=None):
         """Create the distribution instance from a `params` vector."""
         with tf.name_scope(name or "MultivariateNormalDiag"):
             params = tf.convert_to_tensor(params, name="params")
         return tfd.mvn_diag.MultivariateNormalDiag(
             loc=params[..., :event_size],
-            scale_diag=params[..., event_size:],
+            scale_diag=params[..., event_size:]+epsilon,
             validate_args=validate_args,
         )