diff --git a/main.ipynb b/main.ipynb index f67056da88367216c6fead78137266ca30eecf95..a935d5744e6f1e67b9ef4486951507f562dedf54 100644 --- a/main.ipynb +++ b/main.ipynb @@ -414,9 +414,9 @@ "outputs": [], "source": [ "#tf.config.experimental_run_functions_eagerly(False)\n", - "#history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n", - "# validation_data=(pttest[:-1], pttest[:-1]),\n", - "# callbacks=[tensorboard_callback, kl_wu, mmd_wu])" + "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=100, batch_size=512, verbose=1,\n", + " validation_data=(pttest[:-1], pttest[:-1]),\n", + " callbacks=[tensorboard_callback, kl_wu, mmd_wu])" ] }, { @@ -428,9 +428,9 @@ "outputs": [], "source": [ "#tf.config.experimental_run_functions_eagerly(False)\n", - "history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n", - " validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n", - " callbacks=[tensorboard_callback, kl_wu, mmd_wu])" + "#history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=100, batch_size=512, verbose=1,\n", + "# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n", + "# callbacks=[tensorboard_callback, kl_wu, mmd_wu])" ] }, { diff --git a/source/model_utils.py b/source/model_utils.py index da54e949cce234fa14152a7b0fd127908288d492..da70ef3d7ff042c4c32e46fa7fc5155f5924f080 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -129,11 +129,8 @@ class KLDivergenceLayer(Layer): return config def call(self, inputs, **kwargs): - mu, log_var = inputs - kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1) - self.add_loss(self.beta * K.mean(kL_batch), inputs=inputs) return inputs diff --git a/source/models.py b/source/models.py index bbc89f3ac87e3be18582279d71f6983595a19b90..93d39c7feebcb5fc1885e28f9f9ebe5e40a28248 100644 --- a/source/models.py +++ b/source/models.py @@ -280,9 +280,7 @@ class SEQ_2_SEQ_VAE: if self.kl_warmup: def klwarmup(epoch): - value = K.min([epoch / self.kl_warmup, 1]) - print("beta:", value) - kl_beta = value + kl_beta = K.min([epoch / self.kl_warmup, 1]) kl_wu = LambdaCallback(on_epoch_end=lambda epoch, log: klwarmup(epoch)) @@ -297,9 +295,7 @@ class SEQ_2_SEQ_VAE: if self.kl_warmup: def mmdwarmup(epoch): - value = K.min([epoch / self.mmd_warmup, 1]) - print("mmd_beta:", value) - mmd_beta = value + mmd_beta = K.min([epoch / self.mmd_warmup, 1]) mmd_wu = LambdaCallback( on_epoch_end=lambda epoch, log: mmdwarmup(epoch) @@ -488,9 +484,7 @@ class SEQ_2_SEQ_VAEP: if self.kl_warmup: def klwarmup(epoch): - value = K.min([epoch / self.kl_warmup, 1]) - print("beta:", value) - kl_beta = value + kl_beta = K.min([epoch / self.kl_warmup, 1]) kl_wu = LambdaCallback(on_epoch_end=lambda epoch, log: klwarmup(epoch)) @@ -505,9 +499,7 @@ class SEQ_2_SEQ_VAEP: if self.kl_warmup: def mmdwarmup(epoch): - value = K.min([epoch / self.mmd_warmup, 1]) - print("mmd_beta:", value) - mmd_beta = value + mmd_beta = K.min([epoch / self.mmd_warmup, 1]) mmd_wu = LambdaCallback( on_epoch_end=lambda epoch, log: mmdwarmup(epoch)