From bfd1521e502cc58e5bd8833f02dc04583e081ccd Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 5 Jun 2020 15:10:15 +0200
Subject: [PATCH] Implemented KL and MMD warmup on SEQ2SEQ_VAEP in models.py

---
 main.ipynb | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/main.ipynb b/main.ipynb
index ce4168b9..ae8d2a25 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -295,7 +295,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "NAME = 'Baseline_VAE_short_512_10=warmup_begin'\n",
+    "NAME = 'Baseline_VAEP_short_512_10=warmup_begin'\n",
     "log_dir = os.path.abspath(\n",
     "    \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
     ")\n",
@@ -326,10 +326,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n",
-    "                                                                   loss='ELBO+MMD',\n",
-    "                                                                   kl_warmup_epochs=10,\n",
-    "                                                                   mmd_warmup_epochs=10).build()"
+    "# encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,\n",
+    "#                                                                    loss='ELBO+MMD',\n",
+    "#                                                                    kl_warmup_epochs=10,\n",
+    "#                                                                    mmd_warmup_epochs=10).build()"
    ]
   },
   {
@@ -338,10 +338,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n",
-    "#                                                                     loss='ELBO+MMD',\n",
-    "#                                                                     kl_warmup_epochs=10,\n",
-    "#                                                                     mmd_warmup_epochs=10).build()"
+    "encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_VAEP(pttest.shape,\n",
+    "                                                                    loss='ELBO+MMD',\n",
+    "                                                                    kl_warmup_epochs=10,\n",
+    "                                                                    mmd_warmup_epochs=10).build()"
    ]
   },
   {
@@ -405,10 +405,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "tf.config.experimental_run_functions_eagerly(False)\n",
-    "history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=500, batch_size=512, verbose=1,\n",
-    "                  validation_data=(pttest[:-1], pttest[:-1]),\n",
-    "                  callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
+    "# tf.config.experimental_run_functions_eagerly(False)\n",
+    "# history = vae.fit(x=pttrain[:-1], y=pttrain[:-1], epochs=500, batch_size=512, verbose=1,\n",
+    "#                   validation_data=(pttest[:-1], pttest[:-1]),\n",
+    "#                   callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])"
    ]
   },
   {
@@ -419,10 +419,10 @@
    },
    "outputs": [],
    "source": [
-    "# tf.config.experimental_run_functions_eagerly(False)\n",
-    "# history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=500, batch_size=512, verbose=1,\n",
-    "#                    validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
-    "#                    callbacks=[tensorboard_callback])"
+    "tf.config.experimental_run_functions_eagerly(False)\n",
+    "history = vaep.fit(x=pttrain[:-1], y=[pttrain[:-1],pttrain[1:]], epochs=500, batch_size=512, verbose=1,\n",
+    "                   validation_data=(pttest[:-1], [pttest[:-1],pttest[1:]]),\n",
+    "                   callbacks=[tensorboard_callback])"
    ]
   }
  ],
-- 
GitLab