Commit 27de1948 authored by lucas_miranda's avatar lucas_miranda
Browse files

Changed full model for diagonal model in all variational implementations in models.py

parent 35822d92
......@@ -109,42 +109,6 @@ class UncorrelatedFeaturesConstraint(Constraint):
return self.weightage * self.uncorrelated_feature(x)
class MultivariateNormalDiag(tfpl.DistributionLambda):
def __init__(
self,
event_size,
convert_to_tensor_fn=tfd.Distribution.sample,
validate_args=True,
**kwargs
):
super(MultivariateNormalDiag, self).__init__(
lambda t: MultivariateNormalDiag.new(t, event_size, validate_args),
convert_to_tensor_fn,
**kwargs
)
@staticmethod
def new(params, event_size, epsilon=1e-8, validate_args=False, name=None):
"""Create the distribution instance from a `params` vector."""
with tf.name_scope(name or "MultivariateNormalDiag"):
params = tf.convert_to_tensor(params, name="params")
return tfd.Independent(
tfd.mvn_diag.MultivariateNormalDiag(
loc=params[..., :event_size],
scale_diag=params[..., event_size:] + epsilon,
validate_args=validate_args,
),
reinterpreted_batch_ndims=1,
)
@staticmethod
def params_size(event_size, name=None):
"""The number of `params` needed to create a single distribution."""
with tf.name_scope(name or "MultivariateNormalDiag_params_size"):
return 2 * event_size
class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
def __init__(self, *args, **kwargs):
self.is_placeholder = True
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment