diff --git a/main.ipynb b/main.ipynb
index 9197697d6ebdfc603af81226cb8fff70fd038a91..687ad9aeed2228a4a547f89e27ed6f6e49f4912a 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -371,7 +371,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# vaep.summary()"
+    "vaep.summary()"
    ]
   },
   {
diff --git a/source/model_utils.py b/source/model_utils.py
index 93a23534feccf84ecd46862e4683419caa02f673..5dfd7058c7a94323cee086e71f8b870fe18e24cc 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -109,13 +109,37 @@ class UncorrelatedFeaturesConstraint(Constraint):
         return self.weightage * self.uncorrelated_feature(x)


-class GaussianMixtureLayer(Layer):
-    def __init(self, *args, **kwargs):
-        self.is_placeholder = True
-        super(GaussianMixtureLayer, self).__init__(*args, **kwargs)
+class MultivariateNormalDiag(tfpl.DistributionLambda):
+    def __init__(
+        self,
+        event_size,
+        convert_to_tensor_fn=tfd.Distribution.sample,
+        validate_args=False,
+        **kwargs
+    ):
+
+        super(MultivariateNormalDiag, self).__init__(
+            lambda t: MultivariateNormalDiag.new(t, event_size, validate_args),
+            convert_to_tensor_fn,
+            **kwargs
+        )

-    def call(self, inputs, **kwargs):
-        pass
+    @staticmethod
+    def new(params, event_size, validate_args=False, name=None):
+        """Create the distribution instance from a `params` vector."""
+        with tf.name_scope(name or "MultivariateNormalDiag"):
+            params = tf.convert_to_tensor(params, name="params")
+            return tfd.mvn_diag.MultivariateNormalDiag(
+                loc=params[..., :event_size],
+                scale_diag=params[..., event_size:],
+                validate_args=validate_args,
+            )
+
+    @staticmethod
+    def params_size(event_size, name=None):
+        """The number of `params` needed to create a single distribution."""
+        with tf.name_scope(name or "MultivariateNormalDiag_params_size"):
+            return 2 * event_size


 class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
diff --git a/source/models.py b/source/models.py
index a2bda98228b422ce62c7dc6a64db57454f8e50d1..691098cb5e58c7361193657d2a76d47b959d1a8d 100644
--- a/source/models.py
+++ b/source/models.py
@@ -282,7 +282,7 @@ class SEQ_2_SEQ_VAE:
         encoder = Model_E5(encoder)

         encoder = Dense(
-            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+            MultivariateNormalDiag.params_size(self.ENCODING), activation=None
         )(encoder)

         # Define and control custom loss functions
@@ -299,7 +299,7 @@ class SEQ_2_SEQ_VAE:
             )
         )

-        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
+        z = MultivariateNormalDiag(self.ENCODING)(encoder)

         if "ELBO" in self.loss:
             z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
@@ -495,7 +495,7 @@ class SEQ_2_SEQ_VAEP:
         encoder = Model_E5(encoder)

         encoder = Dense(
-            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+            MultivariateNormalDiag.params_size(self.ENCODING), activation=None
         )(encoder)

         # Define and control custom loss functions
@@ -511,7 +511,7 @@ class SEQ_2_SEQ_VAEP:
             )
         )

-        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
+        z = MultivariateNormalDiag(self.ENCODING)(encoder)

         if "ELBO" in self.loss:
             z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
@@ -585,7 +585,7 @@ class SEQ_2_SEQ_VAEP:
         # end-to-end autoencoder
         encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         vaep = Model(
-            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
+            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAEP"
         )

         # Build generator as a separate entity
@@ -883,7 +883,6 @@ class SEQ_2_SEQ_MMVAEP:


 # TODO:
-#      - Try sample, mean and mode for MMDiscrepancyLayer
 #      - Gaussian Mixture + Categorical priors -> Deep Clustering
 #      - prior of equal gaussians
 #      - prior of equal gaussians + gaussian noise on the means (not exactly the same init)
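Note on the change (not part of the diff): the new `MultivariateNormalDiag` head replaces the `tfpl.MultivariateNormalTriL` bottleneck, shrinking the parameterization from `d + d*(d+1)/2` values (mean plus lower-triangular scale) to `2*d` (mean plus diagonal scale) per latent vector of size `d`. Below is a minimal, self-contained sketch of the same wiring using only TFP's public API; `ENCODING` and the input width are illustrative, and the `softplus` on the scale slice is an assumption added here to keep scales positive, since the diff's `new()` passes the raw slice to `scale_diag`.

```python
# Hypothetical sketch of the diagonal-Gaussian encoder head from the diff.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

ENCODING = 6  # latent dimensionality (self.ENCODING in models.py)

inputs = tf.keras.Input(shape=(16,))

# The Dense layer emits params_size(ENCODING) == 2 * ENCODING values:
# ENCODING means followed by ENCODING diagonal scales.
params = tf.keras.layers.Dense(2 * ENCODING, activation=None)(inputs)

# Slice the parameter vector into loc / scale_diag, as new() does.
# NOTE: softplus is an assumption added for positivity; the diff omits it.
z = tfpl.DistributionLambda(
    lambda t: tfd.MultivariateNormalDiag(
        loc=t[..., :ENCODING],
        scale_diag=tf.math.softplus(t[..., ENCODING:]),
    )
)(params)

encoder = tf.keras.Model(inputs, z, name="diag_gaussian_encoder_sketch")
encoder.summary()
```

Compared with the full-covariance `MultivariateNormalTriL` head, the diagonal parameterization assumes independent latent dimensions, which is consistent with the `UncorrelatedFeaturesConstraint` regularizer already present in source/model_utils.py.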