Commit b99b3ace authored by lucas_miranda

Implemented a version of the SEQ_2_SEQ VAE based on tensorflow_probability

parent b1862ad9
@@ -4,13 +4,21 @@ from keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
 import tensorflow as tf
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+
 
 # Helper functions
-def sampling(args, epsilon_std=1.0):
+def sampling(args, epsilon_std=1.0, number_of_components=1, categorical=None):
     z_mean, z_log_sigma = args
-    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
-    return z_mean + K.exp(z_log_sigma) * epsilon
+    if number_of_components == 1:
+        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
+        return z_mean + K.exp(z_log_sigma) * epsilon
+    else:
+        # Implement mixture of gaussians encoding and sampling
+        pass
 
 
 def compute_kernel(x, y):
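Note: the mixture branch above is left as a pass placeholder. A minimal sketch of how it might later be filled in with tensorflow_probability's MixtureSameFamily (the function name, argument shapes, and use of logits are assumptions for illustration, not part of this commit):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    def sample_gaussian_mixture(z_means, z_log_sigmas, categorical_logits):
        # Hypothetical sketch: z_means and z_log_sigmas have shape
        # (batch, components, encoding); categorical_logits has shape
        # (batch, components). Sampling marginalizes over the components.
        mixture = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=categorical_logits),
            components_distribution=tfd.Independent(
                tfd.Normal(loc=z_means, scale=tf.exp(z_log_sigmas)),
                reinterpreted_batch_ndims=1,
            ),
        )
        return mixture.sample()  # shape (batch, encoding)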
@@ -12,6 +12,10 @@ from tensorflow.keras.losses import Huber
 from tensorflow.keras.optimizers import Adam
 from source.model_utils import *
 import tensorflow as tf
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+tfpl = tfp.layers
 
 
 class SEQ_2_SEQ_AE:
@@ -161,6 +165,7 @@ class SEQ_2_SEQ_VAE:
         loss="ELBO+MMD",
         kl_warmup_epochs=0,
         mmd_warmup_epochs=0,
+        prior="standard_normal",
     ):
         self.input_shape = input_shape
         self.CONV_filters = CONV_filters
@@ -172,9 +177,16 @@ class SEQ_2_SEQ_VAE:
         self.ENCODING = ENCODING
         self.learn_rate = learn_rate
         self.loss = loss
+        self.prior = prior
         self.kl_warmup = kl_warmup_epochs
         self.mmd_warmup = mmd_warmup_epochs
 
+        if self.prior == "standard_normal":
+            self.prior = tfd.Independent(
+                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
+                reinterpreted_batch_ndims=1,
+            )
+
         assert (
             "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
@@ -269,8 +281,12 @@ class SEQ_2_SEQ_VAE:
             encoder = BatchNormalization()(encoder)
             encoder = Model_E5(encoder)
 
-        z_mean = Dense(self.ENCODING)(encoder)
-        z_log_sigma = Dense(self.ENCODING)(encoder)
+        # z_mean = Dense(self.ENCODING)(encoder)
+        # z_log_sigma = Dense(self.ENCODING)(encoder)
+
+        encoder = Dense(
+            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
+        )(encoder)
 
         # Define and control custom loss functions
         kl_warmup_callback = False
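Note: MultivariateNormalTriL.params_size(d) returns d + d * (d + 1) / 2: d entries for the mean plus the lower-triangular Cholesky factor of the covariance, so the new Dense head parameterizes a full covariance instead of the diagonal implied by the old z_mean/z_log_sigma pair. A quick check (the ENCODING value is assumed for illustration):

    import tensorflow_probability as tfp

    tfpl = tfp.layers

    ENCODING = 16  # hypothetical latent dimensionality
    # 16 mean entries + 16 * 17 / 2 = 136 Cholesky entries = 152 parameters
    print(tfpl.MultivariateNormalTriL.params_size(ENCODING))  # 152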
@@ -286,9 +302,17 @@ class SEQ_2_SEQ_VAE:
                 )
             )
 
-            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
+            # z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
 
-        z = Lambda(sampling)([z_mean, z_log_sigma])
+        # z = Lambda(sampling)([z_mean, z_log_sigma])
+        z = tfpl.MultivariateNormalTriL(
+            self.ENCODING,
+            activity_regularizer=(
+                tfpl.KLDivergenceRegularizer(self.prior, weight=kl_beta)
+                if "ELBO" in self.loss
+                else None
+            ),
+        )(encoder)
 
         mmd_warmup_callback = False
         if "MMD" in self.loss:
@@ -320,7 +344,7 @@ class SEQ_2_SEQ_VAE:
         x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)
 
         # end-to-end autoencoder
-        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
+        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
         vae = Model(x, x_decoded_mean, name="SEQ_2_SEQ_VAE")
 
         # Build generator as a separate entity
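Note: with z as the output of a distribution layer, the encoder now returns a distribution-valued tensor rather than a point estimate. A hypothetical usage sketch (batch size and input shape are assumptions):

    import numpy as np

    x_batch = np.random.normal(size=(8, 100, 20)).astype("float32")
    z_dist = encoder(x_batch)   # coerces to a MultivariateNormalTriL distribution
    z_sample = z_dist.sample()  # stochastic latent codes, shape (8, ENCODING)
    z_mean = z_dist.mean()      # deterministic embeddings, shape (8, ENCODING)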