diff --git a/main.ipynb b/main.ipynb
index b9df2e6a4579c56b1f50a1b09c7d9b29023361c4..d16ec63a2b21fec2cbdfa6a127b03c7a3bb43710 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -15,7 +15,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -30,7 +30,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {
     "tags": [
      "parameters"
@@ -50,7 +50,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -60,7 +60,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -76,9 +76,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 2.59 s, sys: 818 ms, total: 3.41 s\n",
+      "Wall time: 1.1 s\n"
+     ]
+    }
+   ],
    "source": [
     "%%time\n",
     "DLC_social_1 = project(path=path,#Path where to find the required files\n",
@@ -106,7 +115,17 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading trajectories...\n",
+      "Smoothing trajectories...\n",
+      "Computing distances...\n"
+     ]
+    }
+   ],
    "source": [
     "%%time\n",
     "DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
@@ -336,11 +355,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "k.backend.clear_session()\n",
     "encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n",
     " loss='ELBO+MMD',\n",
     " kl_warmup_epochs=10,\n",
     " mmd_warmup_epochs=10).build()\n",
-    "vae.build(pttest.shape)"
+    "#vae.build(pttest.shape)"
    ]
   },
   {
@@ -400,17 +420,6 @@
     "plot_model(gmvaep, show_shapes=True)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
-   "outputs": [],
-   "source": [
-    "?plot_model"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -419,7 +428,8 @@
    "source": [
     "#np.random.shuffle(pttest)\n",
     "pttrain = pttest[:-15000]\n",
-    "pttest = pttest[-15000:]"
+    "pttest = pttest[-15000:]\n",
+    "pttrain = pttrain[:15000]"
    ]
   },
   {
@@ -439,7 +449,7 @@
    "outputs": [],
    "source": [
     "# tf.config.experimental_run_functions_eagerly(False)\n",
-    "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=500, batch_size=512, verbose=1,\n",
+    "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=2, batch_size=512, verbose=1,\n",
     " validation_data=(pttest[:-1], pttest[:-1]),\n",
     " callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
    ]
diff --git a/source/model_utils.py b/source/model_utils.py
index 9114e067dcf230c1783b8061c1994ec47f68668f..32f3112ff5d87b5e6ed25070c1e78a2c59f564b0 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -7,20 +7,9 @@ import tensorflow as tf
 import tensorflow_probability as tfp
 
 tfd = tfp.distributions
+tfpl = tfp.layers
 
 # Helper functions
-def sampling(args, epsilon_std=1.0, number_of_components=1, categorical=None):
-    z_mean, z_log_sigma = args
-
-    if number_of_components == 1:
-        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
-        return z_mean + K.exp(z_log_sigma) * epsilon
-
-    else:
-        # Implement mixture of gaussians encoding and sampling
-        pass
-
-
 def compute_kernel(x, y):
     x_size = K.shape(x)[0]
     y_size = K.shape(y)[0]
@@ -120,35 +109,20 @@ class UncorrelatedFeaturesConstraint(Constraint):
         return self.weightage * self.uncorrelated_feature(x)
 
 
-class KLDivergenceLayer(Layer):
-
-    """ Identity transform layer that adds KL divergence
-    to the final model loss.
-    """
-
-    def __init__(self, beta=1.0, *args, **kwargs):
+class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
+    def __init__(self, *args, **kwargs):
         self.is_placeholder = True
-        self.beta = beta
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
 
-    def get_config(self):
-        config = super().get_config().copy()
-        config.update({"beta": self.beta})
-        return config
-
-    def call(self, inputs, **kwargs):
-        mu, log_var = inputs
-        KL_batch = (
-            -0.5
-            * self.beta
-            * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
+    def call(self, distribution_a):
+        kl_batch = self._regularizer(distribution_a)
+        self.add_loss(kl_batch, inputs=[distribution_a])
+        self.add_metric(
+            kl_batch, aggregation="mean", name="kl_divergence",
         )
+        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
 
-        self.add_loss(K.mean(KL_batch), inputs=inputs)
-        self.add_metric(KL_batch, aggregation="mean", name="kl_divergence")
-        self.add_metric(self.beta, aggregation="mean", name="kl_rate")
-
-        return inputs
+        return distribution_a
 
 
 class MMDiscrepancyLayer(Layer):
@@ -156,20 +130,21 @@
     to the final model loss.
     """
 
-    def __init__(self, beta=1.0, *args, **kwargs):
+    def __init__(self, prior, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
         self.beta = beta
+        self.prior = prior
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
 
     def get_config(self):
         config = super().get_config().copy()
         config.update({"beta": self.beta})
+        config.update({"prior": self.prior})
         return config
 
     def call(self, z, **kwargs):
-        true_samples = K.random_normal(K.shape(z))
+        true_samples = self.prior.sample(1)
         mmd_batch = self.beta * compute_mmd(true_samples, z)
-        self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
diff --git a/source/models.py b/source/models.py
index a2c89fe1d5367a25fc12ed0b9ec4b17b86105bd6..9ae61fec31600f08cdb7d9333b24e15977f49cae 100644
--- a/source/models.py
+++ b/source/models.py
@@ -281,9 +281,6 @@ class SEQ_2_SEQ_VAE:
         encoder = BatchNormalization()(encoder)
         encoder = Model_E5(encoder)
 
-        # z_mean = Dense(self.ENCODING)(encoder)
-        # z_log_sigma = Dense(self.ENCODING)(encoder)
-
         encoder = Dense(
             tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
         )(encoder)
@@ -302,17 +299,10 @@ class SEQ_2_SEQ_VAE:
                 )
             )
 
-        # z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
+        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
 
-        # z = Lambda(sampling)([z_mean, z_log_sigma])
-        z = tfpl.MultivariateNormalTriL(
-            self.ENCODING,
-            activity_regularizer=(
-                tfpl.KLDivergenceRegularizer(self.prior, weight=kl_beta)
-                if "ELBO" in self.loss
-                else None
-            ),
-        )(encoder)
+        if "ELBO" in self.loss:
+            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
 
         mmd_warmup_callback = False
         if "MMD" in self.loss:
@@ -327,7 +317,7 @@ class SEQ_2_SEQ_VAE:
                 )
             )
 
-            z = MMDiscrepancyLayer(beta=mmd_beta)(z)
+            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
 
         # Define and instantiate generator
         generator = Model_D0(z)
@@ -388,6 +378,7 @@ class SEQ_2_SEQ_VAEP:
         loss="ELBO+MMD",
         kl_warmup_epochs=0,
         mmd_warmup_epochs=0,
prior="standard_normal", ): self.input_shape = input_shape self.CONV_filters = CONV_filters @@ -399,9 +390,16 @@ class SEQ_2_SEQ_VAEP: self.ENCODING = ENCODING self.learn_rate = learn_rate self.loss = loss + self.prior = prior self.kl_warmup = kl_warmup_epochs self.mmd_warmup = mmd_warmup_epochs + if self.prior == "standard_normal": + self.prior = tfd.Independent( + tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1), + reinterpreted_batch_ndims=1, + ) + assert ( "ELBO" in self.loss or "MMD" in self.loss ), "loss must be one of ELBO, MMD or ELBO+MMD (default)" @@ -496,8 +494,9 @@ class SEQ_2_SEQ_VAEP: encoder = BatchNormalization()(encoder) encoder = Model_E5(encoder) - z_mean = Dense(self.ENCODING)(encoder) - z_log_sigma = Dense(self.ENCODING)(encoder) + encoder = Dense( + tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None + )(encoder) # Define and control custom loss functions kl_warmup_callback = False @@ -512,9 +511,10 @@ class SEQ_2_SEQ_VAEP: ) ) - z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma]) + z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder) - z = Lambda(sampling)([z_mean, z_log_sigma]) + if "ELBO" in self.loss: + z = KLDivergenceLayer(self.prior, weight=kl_beta)(z) mmd_warmup_callback = False if "MMD" in self.loss: @@ -583,7 +583,7 @@ class SEQ_2_SEQ_VAEP: )(predictor) # end-to-end autoencoder - encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder") + encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder") vaep = Model( inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE" ) @@ -629,6 +629,7 @@ class SEQ_2_SEQ_MMVAEP: loss="ELBO+MMD", kl_warmup_epochs=0, mmd_warmup_epochs=0, + prior="standard_normal", number_of_components=1, ): self.input_shape = input_shape @@ -641,13 +642,16 @@ class SEQ_2_SEQ_MMVAEP: self.ENCODING = ENCODING self.learn_rate = learn_rate self.loss = loss + self.prior = prior self.kl_warmup = kl_warmup_epochs self.mmd_warmup = mmd_warmup_epochs self.number_of_components = number_of_components - assert ( - self.number_of_components > 0 - ), "The number of components must be an integer greater than zero" + if self.prior == "standard_normal": + self.prior = tfd.Independent( + tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1), + reinterpreted_batch_ndims=1, + ) assert ( "ELBO" in self.loss or "MMD" in self.loss @@ -743,19 +747,9 @@ class SEQ_2_SEQ_MMVAEP: encoder = BatchNormalization()(encoder) encoder = Model_E5(encoder) - # Categorical prior on mixture of Gaussians - categories = Dense(self.number_of_components, activation="softmax") - - # Define mean and log_sigma as lists of vectors with an item per prior component - z_mean = [] - z_log_sigma = [] - for i in range(self.number_of_components): - z_mean.append( - Dense(self.ENCODING, name="{}_gaussian_mean".format(i + 1))(encoder) - ) - z_log_sigma.append( - Dense(self.ENCODING, name="{}_gaussian_sigma".format(i + 1))(encoder) - ) + encoder = Dense( + tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None + )(encoder) # Define and control custom loss functions kl_warmup_callback = False @@ -770,11 +764,10 @@ class SEQ_2_SEQ_MMVAEP: ) ) - z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)( - [z_mean[0], z_log_sigma[0]] - ) + z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder) - z = Lambda(sampling)([z_mean, z_log_sigma]) + if "ELBO" in self.loss: + z = KLDivergenceLayer(self.prior, weight=kl_beta)(z) mmd_warmup_callback = False if "MMD" in self.loss: @@ -803,7 +796,7 @@ class SEQ_2_SEQ_MMVAEP: generator = Model_D5(generator) 
         generator = Model_B5(generator)
         x_decoded_mean = TimeDistributed(
-            Dense(self.input_shape[2]), name="gmvaep_reconstruction"
+            Dense(self.input_shape[2]), name="vaep_reconstruction"
         )(generator)
 
         # Define and instantiate predictor
@@ -839,11 +832,11 @@ class SEQ_2_SEQ_MMVAEP:
         )(predictor)
         predictor = BatchNormalization()(predictor)
         x_predicted_mean = TimeDistributed(
-            Dense(self.input_shape[2]), name="gmvaep_prediction"
+            Dense(self.input_shape[2]), name="vaep_prediction"
         )(predictor)
 
         # end-to-end autoencoder
-        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
+        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         gmvaep = Model(
             inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
         )
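Not part of the patch: a minimal, self-contained sketch of the probabilistic latent head that this diff moves all three models onto, in which a Dense layer parameterizes a tfpl.MultivariateNormalTriL posterior and the KL term against a standard-normal prior is registered through tfpl.KLDivergenceAddLoss (the base class the new KLDivergenceLayer extends). The names ENCODING and LATENT_FEATURES and their concrete values are assumptions for illustration only, not values taken from the repository.

# Illustration only, not part of the patch. ENCODING and LATENT_FEATURES are
# assumed example values; the real models derive them from their constructors.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

ENCODING = 16         # latent dimensionality (assumed)
LATENT_FEATURES = 64  # width of the encoder output feeding the latent head (assumed)

# Standard-normal prior over the latent space, mirroring the "standard_normal"
# branch added to the model constructors in this diff.
prior = tfd.Independent(
    tfd.Normal(loc=tf.zeros(ENCODING), scale=1.0),
    reinterpreted_batch_ndims=1,
)

features = tf.keras.Input(shape=(LATENT_FEATURES,))

# A Dense layer emits exactly the parameters MultivariateNormalTriL expects
# (a mean plus a lower-triangular scale), and the TFP layer turns them into a
# trainable posterior distribution q(z|x).
params = tf.keras.layers.Dense(
    tfpl.MultivariateNormalTriL.params_size(ENCODING), activation=None
)(features)
z = tfpl.MultivariateNormalTriL(ENCODING)(params)

# KLDivergenceAddLoss registers KL(q(z|x) || prior) as an extra model loss;
# the subclassed KLDivergenceLayer in this patch additionally logs it as a metric.
z = tfpl.KLDivergenceAddLoss(prior, weight=1.0)(z)

latent_head = tf.keras.Model(features, z, name="latent_head")

In the patched models, the weight passed to the KL layer is the kl_beta variable that the KL warm-up callback anneals over kl_warmup_epochs, so the fixed 1.0 above stands in for that schedule.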