Commit de37354f authored by lucas_miranda

Implemented KL and MMD warmup on SEQ2SEQ_VAE in models.py
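Warmup (annealing) ramps the weight of the KL and MMD regularization terms from 0 to 1 over the first kl_warmup_epochs / mmd_warmup_epochs of training, so the reconstruction loss dominates early on. The schedule implemented in the diff below is linear; a minimal standalone sketch of the same rule, with illustrative names:

def warmup_beta(epoch, warmup_epochs):
    # linear annealing: 0 at epoch 0, reaching 1 at `warmup_epochs`,
    # then held constant at 1 for the rest of training
    return min(epoch / warmup_epochs, 1.0)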

parent 37c1c84a
@@ -123,6 +123,11 @@ class KLDivergenceLayer(Layer):
        self.beta = beta
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"beta": self.beta})
        return config

    def call(self, inputs, **kwargs):
        mu, log_var = inputs
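The hunk truncates the call body after unpacking the inputs. For reference only, the textbook closed-form KL term for a diagonal Gaussian posterior against a standard normal prior, scaled by beta, would look like this (a sketch of the usual formulation, not the commit's actual code):

        # KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims,
        # averaged over the batch, and scaled by the warmup coefficient
        kl_batch = -0.5 * K.sum(
            1 + log_var - K.square(mu) - K.exp(log_var), axis=-1
        )
        self.add_loss(self.beta * K.mean(kl_batch), inputs=inputs)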
@@ -144,6 +149,11 @@ class MMDiscrepancyLayer(Layer):
        self.beta = beta
        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({"beta": self.beta})
        return config

    def call(self, z, **kwargs):
        true_samples = K.random_normal(K.shape(z))
        mmd_batch = compute_mmd(true_samples, z)
......
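compute_mmd is defined outside this diff. For orientation, a minimal sketch of the usual Gaussian-kernel MMD estimator that such a helper typically implements (the function names and the bandwidth convention here are illustrative assumptions, not taken from this commit):

import tensorflow as tf

def gaussian_kernel(x, y):
    # pairwise Gaussian kernel between rows of x and rows of y
    x_size, y_size, dim = tf.shape(x)[0], tf.shape(y)[0], tf.shape(x)[1]
    tiled_x = tf.tile(tf.expand_dims(x, 1), [1, y_size, 1])
    tiled_y = tf.tile(tf.expand_dims(y, 0), [x_size, 1, 1])
    return tf.exp(
        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2)
        / tf.cast(dim, tf.float32)
    )

def compute_mmd_sketch(x, y):
    # MMD^2 estimate: E[k(x, x)] + E[k(y, y)] - 2 * E[k(x, y)]
    return (
        tf.reduce_mean(gaussian_kernel(x, x))
        + tf.reduce_mean(gaussian_kernel(y, y))
        - 2 * tf.reduce_mean(gaussian_kernel(x, y))
    )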
# @author lucasmiranda42
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional, Dense
@@ -173,6 +175,10 @@ class SEQ_2_SEQ_VAE:
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
@@ -266,13 +272,36 @@ class SEQ_2_SEQ_VAE:
        z_mean = Dense(self.ENCODING)(encoder)
        z_log_sigma = Dense(self.ENCODING)(encoder)

        kl_wu = False
        if "ELBO" in self.loss:
            # beta lives in a backend variable so the warmup callback can
            # change the value the KLDivergenceLayer reads during training;
            # rebinding a plain Python float inside the callback has no effect
            kl_beta = K.variable(1.0, name="kl_beta")
            if self.kl_warmup:

                def klwarmup(epoch):
                    # anneal beta linearly from 0 to 1 over kl_warmup epochs
                    value = min(epoch / self.kl_warmup, 1.0)
                    print("beta:", value)
                    K.set_value(kl_beta, value)

                # run at epoch start so epoch 0 trains with beta = 0
                kl_wu = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: klwarmup(epoch)
                )

            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)(
                [z_mean, z_log_sigma]
            )

        z = Lambda(sampling)([z_mean, z_log_sigma])

        mmd_wu = False
        if "MMD" in self.loss:
            mmd_beta = K.variable(1.0, name="mmd_beta")
            if self.mmd_warmup:

                def mmdwarmup(epoch):
                    value = min(epoch / self.mmd_warmup, 1.0)
                    print("mmd_beta:", value)
                    K.set_value(mmd_beta, value)

                mmd_wu = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: mmdwarmup(epoch)
                )

            z = MMDiscrepancyLayer(beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D0(z)
@@ -319,7 +348,7 @@ class SEQ_2_SEQ_VAE:
            experimental_run_tf_function=False,
        )

        return encoder, generator, vae, kl_wu, mmd_wu
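build() now returns the two warmup callbacks alongside the models (each is False when the corresponding loss term or warmup is disabled), so the caller is responsible for passing them to fit. A minimal usage sketch, assuming the constructor accepts these keyword arguments and that X_train holds preprocessed training sequences (both are illustrative assumptions):

model = SEQ_2_SEQ_VAE(loss="ELBO+MMD", kl_warmup_epochs=10, mmd_warmup_epochs=10)
encoder, generator, vae, kl_wu, mmd_wu = model.build()

# keep only the callbacks that were actually created
warmup_callbacks = [cb for cb in (kl_wu, mmd_wu) if cb]
vae.fit(X_train, X_train, epochs=50, batch_size=256, callbacks=warmup_callbacks)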
class SEQ_2_SEQ_VAEP:
......