Skip to content
Snippets Groups Projects
Commit 0c804556 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Changed full model for diagonal model in all variational implementations in models.py

parent 52f4351f
No related branches found
No related tags found
No related merge requests found
......@@ -125,13 +125,13 @@ class MultivariateNormalDiag(tfpl.DistributionLambda):
)
@staticmethod
def new(params, event_size, validate_args=False, name=None):
def new(params, event_size, epsilon=1e-8, validate_args=False, name=None):
"""Create the distribution instance from a `params` vector."""
with tf.name_scope(name or "MultivariateNormalDiag"):
params = tf.convert_to_tensor(params, name="params")
return tfd.mvn_diag.MultivariateNormalDiag(
loc=params[..., :event_size],
scale_diag=params[..., event_size:],
scale_diag=params[..., event_size:]+epsilon,
validate_args=validate_args,
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment