Commit 14bcf40b authored by lucas_miranda

Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py

parent 0a974291
@@ -118,7 +118,7 @@ class KLDivergenceLayer(Layer):
     to the final model loss.
     """

-    def __init__(self, beta=1, *args, **kwargs):
+    def __init__(self, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
         self.beta = beta
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
@@ -131,7 +131,9 @@ class KLDivergenceLayer(Layer):
     def call(self, inputs, **kwargs):
         mu, log_var = inputs
         kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
         self.add_loss(self.beta * K.mean(kL_batch), inputs=inputs)
+        self.add_metric(self.beta, aggregation="mean", name="kl_rate")

         return inputs
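
For context: kL_batch above is the closed-form KL divergence between the approximate posterior N(mu, exp(log_var)) and a standard normal prior, summed over latent dimensions. self.beta scales that term before add_loss registers it, and the new add_metric call exposes the current weight in the training logs as "kl_rate", which is how the warm-up below becomes visible during training. A minimal numerical check of the closed form, illustrative only and not part of the commit:

    # Check the closed-form KL term against a Monte-Carlo estimate for one latent dim.
    import numpy as np

    mu, log_var = 0.3, np.log(0.5)  # example posterior parameters q = N(mu, sigma^2)
    closed_form = -0.5 * (1 + log_var - mu ** 2 - np.exp(log_var))

    # Monte-Carlo estimate of KL(q || N(0, 1)) for comparison
    rng = np.random.default_rng(0)
    sigma = np.exp(0.5 * log_var)
    z = rng.normal(mu, sigma, size=1_000_000)
    log_q = -0.5 * (np.log(2 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var))
    log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)

    print(closed_form, np.mean(log_q - log_p))  # both are approximately 0.14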
@@ -141,7 +143,7 @@ class MMDiscrepancyLayer(Layer):
     to the final model loss.
     """

-    def __init__(self, beta=1, *args, **kwargs):
+    def __init__(self, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
         self.beta = beta
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
@@ -156,5 +158,6 @@ class MMDiscrepancyLayer(Layer):
         mmd_batch = compute_mmd(true_samples, z)
         self.add_loss(self.beta * K.mean(mmd_batch), inputs=z)
+        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")

         return z
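
The compute_mmd call above is defined elsewhere in models.py and is not part of this diff. For readers unfamiliar with it, the sketch below shows the RBF-kernel MMD estimator commonly used in MMD-VAEs; it is a generic illustration under that assumption, and the repository's own implementation may differ in details:

    # Generic RBF-kernel MMD estimator (illustrative names, not the repository's code).
    import tensorflow as tf


    def _rbf_kernel(x, y):
        """Pairwise Gaussian kernel between two batches of latent samples."""
        x_size, y_size, dim = tf.shape(x)[0], tf.shape(y)[0], tf.shape(x)[1]
        tiled_x = tf.tile(tf.expand_dims(x, 1), tf.stack([1, y_size, 1]))
        tiled_y = tf.tile(tf.expand_dims(y, 0), tf.stack([x_size, 1, 1]))
        return tf.exp(
            -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
        )


    def mmd_sketch(true_samples, z):
        """Biased estimate of MMD^2 between prior samples and encoded samples."""
        true_samples = tf.cast(true_samples, tf.float32)
        z = tf.cast(z, tf.float32)
        return (
            tf.reduce_mean(_rbf_kernel(true_samples, true_samples))
            + tf.reduce_mean(_rbf_kernel(z, z))
            - 2 * tf.reduce_mean(_rbf_kernel(true_samples, z))
        )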
@@ -273,32 +273,32 @@ class SEQ_2_SEQ_VAE:
         z_log_sigma = Dense(self.ENCODING)(encoder)

         # Define and control custom loss functions
-        kl_wu = False
+        kl_warmup_callback = False
         if "ELBO" in self.loss:

-            kl_beta = 1
+            kl_beta = K.variable(1.0, name="kl_beta")
             if self.kl_warmup:

-                def klwarmup(epoch):
-                    kl_beta = K.min([epoch / self.kl_warmup, 1])
-
-                kl_wu = LambdaCallback(on_epoch_end=lambda epoch, log: klwarmup(epoch))
+                kl_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        kl_beta, K.min([epoch / self.kl_warmup, 1])
+                    )
+                )

             z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

         z = Lambda(sampling)([z_mean, z_log_sigma])

-        mmd_wu = False
+        mmd_warmup_callback = False
         if "MMD" in self.loss:

-            mmd_beta = 1
-            if self.kl_warmup:
-
-                def mmdwarmup(epoch):
-                    mmd_beta = K.min([epoch / self.mmd_warmup, 1])
-
-                mmd_wu = LambdaCallback(
-                    on_epoch_end=lambda epoch, log: mmdwarmup(epoch)
-                )
+            mmd_beta = K.variable(1.0, name="mmd_beta")
+            if self.mmd_warmup:
+                mmd_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
+                    )
+                )

             z = MMDiscrepancyLayer(beta=mmd_beta)(z)
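
The substantive fix in this hunk: the removed klwarmup/mmdwarmup closures only rebound a local Python name, so the beta actually used in the loss never changed; the MMD branch was additionally gated on self.kl_warmup instead of self.mmd_warmup. The new code makes kl_beta and mmd_beta Keras backend variables and has a LambdaCallback call K.set_value on them at the start of each epoch, so each weight ramps linearly from 0 to 1 over the configured number of warm-up epochs. A standalone sketch of the resulting schedule (Python's min stands in for K.min, and the warm-up length of 10 is hypothetical):

    # K.set_value mutates the same backend variable the loss term references.
    import tensorflow.keras.backend as K

    kl_warmup = 10                                 # hypothetical warm-up length (epochs)
    kl_beta = K.variable(1.0, name="kl_beta")

    for epoch in range(12):
        K.set_value(kl_beta, min(epoch / kl_warmup, 1.0))  # what on_epoch_begin does
        print(epoch, float(K.get_value(kl_beta)))
    # 0 -> 0.0, 1 -> 0.1, ..., 9 -> 0.9, 10 onwards -> 1.0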
@@ -348,7 +348,7 @@ class SEQ_2_SEQ_VAE:
             experimental_run_tf_function=False,
         )

-        return encoder, generator, vae, kl_wu, mmd_wu
+        return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback
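
With the clearer names, the warm-up callbacks returned here are meant to be forwarded to Keras during training. A hedged usage sketch follows; model_builder, x_train and the epoch count are placeholders rather than names taken from this commit:

    # Placeholder names throughout; the point is that the returned callbacks go to fit().
    encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = model_builder()

    # Each callback is False when the corresponding warm-up is disabled, so filter
    # out falsy entries before handing the list to Keras.
    callbacks = [cb for cb in (kl_warmup_callback, mmd_warmup_callback) if cb]

    vae.fit(x_train, x_train, epochs=100, callbacks=callbacks)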
class SEQ_2_SEQ_VAEP:
@@ -477,34 +477,30 @@ class SEQ_2_SEQ_VAEP:
         z_log_sigma = Dense(self.ENCODING)(encoder)

         # Define and control custom loss functions
-        kl_wu = False
+        kl_warmup_callback = False
        if "ELBO" in self.loss:

-            kl_beta = 1
+            kl_beta = K.variable(1.0, name="kl_beta")
             if self.kl_warmup:

-                def klwarmup(epoch):
-                    kl_beta = K.min([epoch / self.kl_warmup, 1])
-
-                kl_wu = LambdaCallback(
-                    on_epoch_begin=lambda epoch, log: klwarmup(epoch)
-                )
+                kl_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        kl_beta, K.min([epoch / self.kl_warmup, 1])
+                    )
+                )

             z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

         z = Lambda(sampling)([z_mean, z_log_sigma])

-        mmd_wu = False
+        mmd_warmup_callback = False
         if "MMD" in self.loss:

-            mmd_beta = 1
-            if self.kl_warmup:
-
-                def mmdwarmup(epoch):
-                    mmd_beta = K.min([epoch / self.mmd_warmup, 1])
-
-                mmd_wu = LambdaCallback(
-                    on_epoch_begin=lambda epoch, log: mmdwarmup(epoch)
-                )
+            mmd_beta = K.variable(1.0, name="mmd_beta")
+            if self.mmd_warmup:
+                mmd_warmup_callback = LambdaCallback(
+                    on_epoch_begin=lambda epoch, logs: K.set_value(
+                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
+                    )
+                )

             z = MMDiscrepancyLayer(beta=mmd_beta)(z)
@@ -594,7 +590,7 @@ class SEQ_2_SEQ_VAEP:
             experimental_run_tf_function=False,
         )

-        return encoder, generator, vaep, kl_wu, mmd_wu
+        return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback
class SEQ_2_SEQ_MMVAE: