Commit 0b0ddad3 authored by lucas_miranda's avatar lucas_miranda
Browse files

Changed prior initialization

parent 39be1caa
......@@ -108,13 +108,14 @@ class GMVAE:
),
components_distribution=tfd.MultivariateNormalDiag(
loc=tf.Variable(
Orthogonal()(
he_uniform()(
[self.number_of_components, self.ENCODING],
),
name="prior_means",
),
scale_diag=tfp.util.TransformedVariable(
tf.ones([self.number_of_components, self.ENCODING]),
tf.ones([self.number_of_components, self.ENCODING])
/ self.number_of_components,
tfb.Softplus(),
name="prior_scales",
),
......@@ -397,8 +398,7 @@ class GMVAE:
// 2,
name="cluster_means",
activation=None,
kernel_initializer=Orthogonal(), # An alternative is a constant initializer with a matrix of values
# computed from the labels, we could also initialize the prior this way, and update it every N epochs
kernel_initializer=he_uniform(),
)(encoder)
z_gauss_var = Dense(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment