From 55ba3a09a30f3204a8b56a424ed739b28d33befc Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Thu, 4 Jun 2020 12:48:27 +0200 Subject: [PATCH] Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py --- main.ipynb | 12 ++++++------ source/model_utils.py | 3 --- source/models.py | 16 ++++------------ 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/main.ipynb b/main.ipynb index f67056da..a935d574 100644 --- a/main.ipynb +++ b/main.ipynb @@ -414,9 +414,9 @@ "outputs": [], "source": [ "#tf.config.experimental_run_functions_eagerly(False)\n", - "#history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n", - "# validation_data=(pttest[:-1], pttest[:-1]),\n", - "# callbacks=[tensorboard_callback, kl_wu, mmd_wu])" + "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n", + " validation_data=(pttest[:-1], pttest[:-1]),\n", + " callbacks=[tensorboard_callback, kl_wu, mmd_wu])" ] }, { @@ -428,9 +428,9 @@ "outputs": [], "source": [ "#tf.config.experimental_run_functions_eagerly(False)\n", - "history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n", - " validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n", - " callbacks=[tensorboard_callback, kl_wu, mmd_wu])" + "#history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n", + "# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n", + "# callbacks=[tensorboard_callback, kl_wu, mmd_wu])" ] }, { diff --git a/source/model_utils.py b/source/model_utils.py index da54e949..da70ef3d 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -129,11 +129,8 @@ class KLDivergenceLayer(Layer): return config def call(self, inputs, **kwargs): - mu, log_var = inputs - kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1) - self.add_loss(self.beta * K.mean(kL_batch), inputs=inputs) return inputs diff --git a/source/models.py b/source/models.py index bbc89f3a..93d39c7f 100644 --- a/source/models.py +++ b/source/models.py @@ -280,9 +280,7 @@ class SEQ_2_SEQ_VAE: if self.kl_warmup: def klwarmup(epoch): - value = K.min([epoch / self.kl_warmup, 1]) - print("beta:", value) - kl_beta = value + kl_beta = K.min([epoch / self.kl_warmup, 1]) kl_wu = LambdaCallback(on_epoch_end=lambda epoch, log: klwarmup(epoch)) @@ -297,9 +295,7 @@ class SEQ_2_SEQ_VAE: if self.kl_warmup: def mmdwarmup(epoch): - value = K.min([epoch / self.mmd_warmup, 1]) - print("mmd_beta:", value) - mmd_beta = value + mmd_beta = K.min([epoch / self.mmd_warmup, 1]) mmd_wu = LambdaCallback( on_epoch_end=lambda epoch, log: mmdwarmup(epoch) @@ -488,9 +484,7 @@ class SEQ_2_SEQ_VAEP: if self.kl_warmup: def klwarmup(epoch): - value = K.min([epoch / self.kl_warmup, 1]) - print("beta:", value) - kl_beta = value + kl_beta = K.min([epoch / self.kl_warmup, 1]) kl_wu = LambdaCallback(on_epoch_end=lambda epoch, log: klwarmup(epoch)) @@ -505,9 +499,7 @@ class SEQ_2_SEQ_VAEP: if self.kl_warmup: def mmdwarmup(epoch): - value = K.min([epoch / self.mmd_warmup, 1]) - print("mmd_beta:", value) - mmd_beta = value + mmd_beta = K.min([epoch / self.mmd_warmup, 1]) mmd_wu = LambdaCallback( on_epoch_end=lambda epoch, log: mmdwarmup(epoch) -- GitLab