Commit d2096481 authored by lucas_miranda

Implemented version of SEQ_2_SEQ VAE based on tensorflow_probability

parent b99b3ace
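
For context (not part of the diff itself): the commit swaps the hand-rolled `z_mean`/`z_log_sigma` heads and the `sampling` Lambda for `tensorflow_probability` layers. A minimal sketch of that pattern follows; the latent size and the standalone `latent_head` model are illustrative only and do not exist in the repository.

```python
# Minimal sketch only; ENCODING and `latent_head` are illustrative, not repo code.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

ENCODING = 6  # hypothetical latent dimensionality

# Standard-normal prior over the latent code (the commit's prior="standard_normal" branch)
prior = tfd.Independent(
    tfd.Normal(loc=tf.zeros(ENCODING), scale=1),
    reinterpreted_batch_ndims=1,
)

# A Dense layer emits the parameters of a full-covariance Gaussian posterior;
# MultivariateNormalTriL turns them into a distribution and samples z, and the
# KL regularizer adds the ELBO's KL(q(z|x) || p(z)) term to the model loss.
latent_head = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(
            tfpl.MultivariateNormalTriL.params_size(ENCODING), activation=None
        ),
        tfpl.MultivariateNormalTriL(
            ENCODING,
            activity_regularizer=tfpl.KLDivergenceRegularizer(prior, weight=1.0),
        ),
    ]
)
```

The diff wires the same KL term either through this `activity_regularizer` or through the new `KLDivergenceLayer`, a thin subclass of `tfpl.KLDivergenceAddLoss` that additionally logs the KL value and its warm-up weight as metrics.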
@@ -230,15 +230,16 @@
```
%% Cell type:code id: tags:
``` python
k.backend.clear_session()
encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
                                                                                  loss='ELBO+MMD',
                                                                                  kl_warmup_epochs=10,
                                                                                  mmd_warmup_epochs=10).build()
vae.build(pttest.shape)
#vae.build(pttest.shape)
```
%% Cell type:code id: tags:
``` python
@@ -280,19 +281,14 @@
```
%% Cell type:code id: tags:
``` python
?plot_model
```
%% Cell type:code id: tags:
``` python
#np.random.shuffle(pttest)
pttrain = pttest[:-15000]
pttest = pttest[-15000:]
pttrain = pttrain[:15000]
```
%% Cell type:code id: tags:
``` python
@@ -302,11 +298,11 @@
%% Cell type:code id: tags:
``` python
# tf.config.experimental_run_functions_eagerly(False)
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=500, batch_size=512, verbose=1,
history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=2, batch_size=512, verbose=1,
                  validation_data=(pttest[:-1], pttest[:-1]),
                  callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```
%% Cell type:code id: tags:
......
@@ -7,20 +7,9 @@ import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
# Helper functions
def sampling(args, epsilon_std=1.0, number_of_components=1, categorical=None):
    z_mean, z_log_sigma = args
    if number_of_components == 1:
        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
        return z_mean + K.exp(z_log_sigma) * epsilon
    else:
        # Implement mixture of gaussians encoding and sampling
        pass
def compute_kernel(x, y):
    x_size = K.shape(x)[0]
    y_size = K.shape(y)[0]
@@ -120,35 +109,20 @@ class UncorrelatedFeaturesConstraint(Constraint):
        return self.weightage * self.uncorrelated_feature(x)
class KLDivergenceLayer(Layer):
    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """
    def __init__(self, beta=1.0, *args, **kwargs):
class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        self.beta = beta
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)
    def get_config(self):
        config = super().get_config().copy()
        config.update({"beta": self.beta})
        return config
    def call(self, inputs, **kwargs):
        mu, log_var = inputs
        KL_batch = (
            -0.5
            * self.beta
            * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
    def call(self, distribution_a):
        kl_batch = self._regularizer(distribution_a)
        self.add_loss(kl_batch, inputs=[distribution_a])
        self.add_metric(
            kl_batch, aggregation="mean", name="kl_divergence",
        )
        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
        self.add_loss(K.mean(KL_batch), inputs=inputs)
        self.add_metric(KL_batch, aggregation="mean", name="kl_divergence")
        self.add_metric(self.beta, aggregation="mean", name="kl_rate")
        return inputs
        return distribution_a
class MMDiscrepancyLayer(Layer):
@@ -156,20 +130,21 @@ class MMDiscrepancyLayer(Layer):
    to the final model loss.
    """
    def __init__(self, beta=1.0, *args, **kwargs):
    def __init__(self, prior, beta=1.0, *args, **kwargs):
        self.is_placeholder = True
        self.beta = beta
        self.prior = prior
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
    def get_config(self):
        config = super().get_config().copy()
        config.update({"beta": self.beta})
        config.update({"prior": self.prior})
        return config
    def call(self, z, **kwargs):
        true_samples = K.random_normal(K.shape(z))
        true_samples = self.prior.sample(1)
        mmd_batch = self.beta * compute_mmd(true_samples, z)
        self.add_loss(K.mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
......
@@ -281,9 +281,6 @@ class SEQ_2_SEQ_VAE:
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)
        # z_mean = Dense(self.ENCODING)(encoder)
        # z_log_sigma = Dense(self.ENCODING)(encoder)
        encoder = Dense(
            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
        )(encoder)
@@ -302,17 +299,10 @@
                )
            )
        # z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
        # z = Lambda(sampling)([z_mean, z_log_sigma])
        z = tfpl.MultivariateNormalTriL(
            self.ENCODING,
            activity_regularizer=(
                tfpl.KLDivergenceRegularizer(self.prior, weight=kl_beta)
                if "ELBO" in self.loss
                else None
            ),
        )(encoder)
        if "ELBO" in self.loss:
            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
        mmd_warmup_callback = False
        if "MMD" in self.loss:
@@ -327,7 +317,7 @@
                )
            )
            z = MMDiscrepancyLayer(beta=mmd_beta)(z)
            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
        # Define and instantiate generator
        generator = Model_D0(z)
@@ -388,6 +378,7 @@ class SEQ_2_SEQ_VAEP:
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
@@ -399,9 +390,16 @@ class SEQ_2_SEQ_VAEP:
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs
        if self.prior == "standard_normal":
            self.prior = tfd.Independent(
                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
                reinterpreted_batch_ndims=1,
            )
        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
@@ -496,8 +494,9 @@ class SEQ_2_SEQ_VAEP:
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)
        z_mean = Dense(self.ENCODING)(encoder)
        z_log_sigma = Dense(self.ENCODING)(encoder)
        encoder = Dense(
            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
        )(encoder)
        # Define and control custom loss functions
        kl_warmup_callback = False
@@ -512,9 +511,10 @@ class SEQ_2_SEQ_VAEP:
                )
            )
            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
        z = Lambda(sampling)([z_mean, z_log_sigma])
        if "ELBO" in self.loss:
            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
        mmd_warmup_callback = False
        if "MMD" in self.loss:
@@ -583,7 +583,7 @@ class SEQ_2_SEQ_VAEP:
        )(predictor)
        # end-to-end autoencoder
        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        vaep = Model(
            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
        )
@@ -629,6 +629,7 @@ class SEQ_2_SEQ_MMVAEP:
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
        number_of_components=1,
    ):
        self.input_shape = input_shape
@@ -641,13 +642,16 @@ class SEQ_2_SEQ_MMVAEP:
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        assert (
            self.number_of_components > 0
        ), "The number of components must be an integer greater than zero"
        if self.prior == "standard_normal":
            self.prior = tfd.Independent(
                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
                reinterpreted_batch_ndims=1,
            )
        assert (
            "ELBO" in self.loss or "MMD" in self.loss
@@ -743,19 +747,9 @@ class SEQ_2_SEQ_MMVAEP:
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)
        # Categorical prior on mixture of Gaussians
        categories = Dense(self.number_of_components, activation="softmax")
        # Define mean and log_sigma as lists of vectors with an item per prior component
        z_mean = []
        z_log_sigma = []
        for i in range(self.number_of_components):
            z_mean.append(
                Dense(self.ENCODING, name="{}_gaussian_mean".format(i + 1))(encoder)
            )
            z_log_sigma.append(
                Dense(self.ENCODING, name="{}_gaussian_sigma".format(i + 1))(encoder)
            )
        encoder = Dense(
            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
        )(encoder)
        # Define and control custom loss functions
        kl_warmup_callback = False
@@ -770,11 +764,10 @@ class SEQ_2_SEQ_MMVAEP:
                )
            )
            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)(
                [z_mean[0], z_log_sigma[0]]
            )
        z = tfpl.MultivariateNormalTriL(self.ENCODING)(encoder)
        z = Lambda(sampling)([z_mean, z_log_sigma])
        if "ELBO" in self.loss:
            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)
        mmd_warmup_callback = False
        if "MMD" in self.loss:
@@ -803,7 +796,7 @@ class SEQ_2_SEQ_MMVAEP:
        generator = Model_D5(generator)
        generator = Model_B5(generator)
        x_decoded_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="gmvaep_reconstruction"
            Dense(self.input_shape[2]), name="vaep_reconstruction"
        )(generator)
        # Define and instantiate predictor
@@ -839,11 +832,11 @@ class SEQ_2_SEQ_MMVAEP:
        )(predictor)
        predictor = BatchNormalization()(predictor)
        x_predicted_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="gmvaep_prediction"
            Dense(self.input_shape[2]), name="vaep_prediction"
        )(predictor)
        # end-to-end autoencoder
        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        gmvaep = Model(
            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
        )
......