Commit 942674c7 authored by lucas_miranda

Implemented KL and MMD warmup on SEQ_2_SEQ_VAEP in models.py

parent de37354f
@@ -272,6 +272,7 @@ class SEQ_2_SEQ_VAE:
z_mean = Dense(self.ENCODING)(encoder)
z_log_sigma = Dense(self.ENCODING)(encoder)
# Define and control custom loss functions
kl_wu = False
if "ELBO" in self.loss:
@@ -294,12 +295,15 @@ class SEQ_2_SEQ_VAE:
mmd_beta = K.variable(1.0, name="mmd_beta")  # backend variable, so the callback can update it in place
if self.mmd_warmup:

    def mmdwarmup(epoch):
        value = min(epoch / self.mmd_warmup, 1.0)
        print("mmd_beta:", value)
        K.set_value(mmd_beta, value)

    mmd_wu = LambdaCallback(
        on_epoch_end=lambda epoch, log: mmdwarmup(epoch)
    )

z = MMDiscrepancyLayer(beta=mmd_beta)(z)
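`MMDiscrepancyLayer` is defined elsewhere in models.py and is not part of this diff. As context for how a layer can consume a `beta` that is a backend variable, here is a minimal, hypothetical sketch built on the standard RBF-kernel MMD estimator and Keras `add_loss`; the actual layer in the repository may differ:

```python
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Layer


def rbf_kernel(x, y):
    """RBF kernel evaluated on all pairs of rows of x and y."""
    sq_dist = K.sum(
        K.square(K.expand_dims(x, 1) - K.expand_dims(y, 0)), axis=-1
    )
    return K.exp(-sq_dist / K.cast(K.shape(x)[1], "float32"))


class BetaMMDLayer(Layer):
    """Identity layer that adds a beta-weighted MMD penalty to the model loss."""

    def __init__(self, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta  # may be a K.variable updated by a warmup callback

    def call(self, z):
        prior = K.random_normal(K.shape(z))  # samples from the N(0, I) prior
        mmd = (
            K.mean(rbf_kernel(prior, prior))
            + K.mean(rbf_kernel(z, z))
            - 2 * K.mean(rbf_kernel(prior, z))
        )
        self.add_loss(self.beta * mmd)
        return z
```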
@@ -472,13 +476,40 @@ class SEQ_2_SEQ_VAEP:
z_mean = Dense(self.ENCODING)(encoder)
z_log_sigma = Dense(self.ENCODING)(encoder)
# Define and control custom loss functions
kl_wu = False
if "ELBO" in self.loss:
    # kl_beta must be a backend variable (not a plain float), so that the
    # warmup callback below can update the value the KL layer actually uses
    kl_beta = K.variable(1.0, name="kl_beta")
    if self.kl_warmup:

        def klwarmup(epoch):
            value = min(epoch / self.kl_warmup, 1.0)
            print("kl_beta:", value)
            K.set_value(kl_beta, value)

        kl_wu = LambdaCallback(
            on_epoch_end=lambda epoch, log: klwarmup(epoch)
        )

    z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

z = Lambda(sampling)([z_mean, z_log_sigma])
mmd_wu = False
if "MMD" in self.loss:
    # same pattern as the ELBO branch: a backend variable mutated by the callback
    mmd_beta = K.variable(1.0, name="mmd_beta")
    if self.mmd_warmup:

        def mmdwarmup(epoch):
            value = min(epoch / self.mmd_warmup, 1.0)
            print("mmd_beta:", value)
            K.set_value(mmd_beta, value)

        mmd_wu = LambdaCallback(
            on_epoch_end=lambda epoch, log: mmdwarmup(epoch)
        )

    z = MMDiscrepancyLayer(beta=mmd_beta)(z)
# Define and instantiate generator
generator = Model_D0(z)
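A subtlety in the warmup blocks above: a plain assignment such as `mmd_beta = value` inside the callback would only rebind a local name and never reach the layer, which is why the betas are Keras backend variables mutated with `K.set_value`. The schedule itself can be sanity-checked in isolation; this is a self-contained sketch with a hypothetical ten-epoch warmup:

```python
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import LambdaCallback

warmup_epochs = 10  # hypothetical warmup length
beta = K.variable(1.0, name="beta")

warmup_cb = LambdaCallback(
    on_epoch_end=lambda epoch, logs: K.set_value(
        beta, min(epoch / warmup_epochs, 1.0)
    )
)

# Simulate the calls Keras makes during fit: beta ramps 0.0, 0.1, ...
# and saturates at 1.0 from epoch 10 onwards
for epoch in range(12):
    warmup_cb.on_epoch_end(epoch, logs={})
    print(epoch, float(K.get_value(beta)))
```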
@@ -492,7 +523,9 @@ class SEQ_2_SEQ_VAEP:
generator = Model_B4(generator)
generator = Model_D5(generator)
generator = Model_B5(generator)
x_decoded_mean = TimeDistributed(
    Dense(self.input_shape[2]), name="vaep_reconstruction"
)(generator)
# Define and instantiate predictor
predictor = Dense(
@@ -526,7 +559,9 @@ class SEQ_2_SEQ_VAEP:
)
)(predictor)
predictor = BatchNormalization()(predictor)
x_predicted_mean = TimeDistributed(
    Dense(self.input_shape[2]), name="vaep_prediction"
)(predictor)
# end-to-end autoencoder
encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
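Naming the two `TimeDistributed` heads (`vaep_reconstruction` and `vaep_prediction`) makes them addressable individually at compile time. The hunk below shows only the tail of the actual `compile` call, so purely as an illustration, per-output losses could be wired up as follows; the optimizer and loss choices here are assumptions, not necessarily what models.py uses:

```python
vaep.compile(
    optimizer="adam",
    loss={
        "vaep_reconstruction": "mean_absolute_error",
        "vaep_prediction": "mean_absolute_error",
    },
    loss_weights={"vaep_reconstruction": 1.0, "vaep_prediction": 1.0},
)
```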
@@ -561,7 +596,7 @@ class SEQ_2_SEQ_VAEP:
experimental_run_tf_function=False,
)
return encoder, generator, vaep, kl_wu, mmd_wu
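Since `build` now returns the warmup callbacks alongside the models (with `False` as a placeholder when a warmup is disabled), callers are expected to forward them to `fit`. A minimal usage sketch; `X_train`, `X_next` (the shifted prediction targets), and the epoch/batch settings are hypothetical:

```python
# encoder, generator, vaep, kl_wu, mmd_wu come from SEQ_2_SEQ_VAEP(...).build(...)
callbacks = [cb for cb in (kl_wu, mmd_wu) if cb]  # drop the False placeholders

vaep.fit(
    x=X_train,
    y=[X_train, X_next],  # reconstruction target and prediction target
    epochs=100,
    batch_size=256,
    callbacks=callbacks,
)
```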
class SEQ_2_SEQ_MMVAE:
@@ -569,8 +604,6 @@ class SEQ_2_SEQ_MMVAE:
# TODO:
# - Add learning rate scheduler callback
# - KL / MMD warmup (Ladder Variational Autoencoders)
# - Gaussian Mixture + Categorical priors -> Deep Clustering
# - free bits paper
# - Attention mechanism for encoder / decoder (does it make sense?)
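On the "free bits paper" TODO item: free bits (Kingma et al., "Improved Variational Inference with Inverse Autoregressive Flow") clamp the per-dimension KL from below, so the optimizer cannot collapse individual latent units early in training. A hedged sketch of the modified KL term, assuming `z_log_sigma` holds the log-variance and with `lambda_fb` as a hypothetical threshold:

```python
import tensorflow.keras.backend as K

lambda_fb = 0.25  # hypothetical free-bits threshold (nats per latent dimension)


def free_bits_kl(z_mean, z_log_sigma):
    # Analytical KL(N(mu, sigma**2) || N(0, 1)) per latent dimension,
    # assuming z_log_sigma = log(sigma**2)
    kl_per_dim = -0.5 * (
        1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma)
    )
    # Average over the batch, clamp each dimension at lambda_fb, sum over dims
    return K.sum(K.maximum(K.mean(kl_per_dim, axis=0), lambda_fb))
```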