From bfd1521e502cc58e5bd8833f02dc04583e081ccd Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Fri, 5 Jun 2020 15:10:15 +0200 Subject: [PATCH] Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py --- main.ipynb | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/main.ipynb b/main.ipynb index ce4168b9..ae8d2a25 100644 --- a/main.ipynb +++ b/main.ipynb @@ -295,7 +295,7 @@ "metadata": {}, "outputs": [], "source": [ - "NAME = 'Baseline_VAE_short_512_10=warmup_begin'\n", + "NAME = 'Baseline_VAEP_short_512_10=warmup_begin'\n", "log_dir = os.path.abspath(\n", " \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n", ")\n", @@ -326,10 +326,10 @@ "metadata": {}, "outputs": [], "source": [ - "encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n", - " loss='ELBO+MMD',\n", - " kl_warmup_epochs=10,\n", - " mmd_warmup_epochs=10).build()" + "# encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n", + "# loss='ELBO+MMD',\n", + "# kl_warmup_epochs=10,\n", + "# mmd_warmup_epochs=10).build()" ] }, { @@ -338,10 +338,10 @@ "metadata": {}, "outputs": [], "source": [ - "# encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n", - "# loss='ELBO+MMD',\n", - "# kl_warmup_epochs=10,\n", - "# mmd_warmup_epochs=10).build()" + "encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n", + " loss='ELBO+MMD',\n", + " kl_warmup_epochs=10,\n", + " mmd_warmup_epochs=10).build()" ] }, { @@ -405,10 +405,10 @@ "metadata": {}, "outputs": [], "source": [ - "tf.config.experimental_run_functions_eagerly(False)\n", - "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=500, batch_size=512, verbose=1,\n", - " validation_data=(pttest[:-1], pttest[:-1]),\n", - " callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])" + "# tf.config.experimental_run_functions_eagerly(False)\n", + "# history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=500, batch_size=512, verbose=1,\n", + "# validation_data=(pttest[:-1], pttest[:-1]),\n", + "# callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])" ] }, { @@ -419,10 +419,10 @@ }, "outputs": [], "source": [ - "# tf.config.experimental_run_functions_eagerly(False)\n", - "# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=500, batch_size=512, verbose=1,\n", - "# validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n", - "# callbacks=[tensorboard_callback])" + "tf.config.experimental_run_functions_eagerly(False)\n", + "history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=500, batch_size=512, verbose=1,\n", + " validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n", + " callbacks=[tensorboard_callback])" ] } ], -- GitLab