From 02ca2b4c15336ceeb0852fb36bc7f537f3f4aabe Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Wed, 27 May 2020 13:43:25 +0200 Subject: [PATCH] Added Batch Normalization to SEQ2SEQ_VAE --- main.ipynb | 8 ++++---- source/models.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/main.ipynb b/main.ipynb index a10e64e4..6c134313 100644 --- a/main.ipynb +++ b/main.ipynb @@ -303,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "NAME = 'Baseline_AE_BatchNorm'\n", + "NAME = 'Baseline_VAE_BatchNorm'\n", "log_dir = os.path.abspath(\n", " \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n", ")\n", @@ -325,7 +325,7 @@ "metadata": {}, "outputs": [], "source": [ - "encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()" + "#encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()" ] }, { @@ -334,7 +334,7 @@ "metadata": {}, "outputs": [], "source": [ - "#encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()" + "encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()" ] }, { @@ -376,7 +376,7 @@ "outputs": [], "source": [ "tf.config.experimental_run_functions_eagerly(False)\n", - "history = ae.fit(pttrain, pttrain, epochs=50, batch_size=256, verbose=1, validation_data=(pttest, pttest),\n", + "history = vae.fit(pttrain, pttrain, epochs=50, batch_size=256, verbose=1, validation_data=(pttest, pttest),\n", " callbacks=[tensorboard_callback])" ] }, diff --git a/source/models.py b/source/models.py index fd46fda7..9d0b97d1 100644 --- a/source/models.py +++ b/source/models.py @@ -304,8 +304,7 @@ class SEQ_2_SEQ_MMVAE: # - Change LSTMs for GRU (done!) # - Tied/Untied weights (done!) # - orthogonal/non-orthogonal weights (done!) -# - Unit Norm constraint - +# - Unit Norm constraint (done!) # - add batch normalization # - add He initialization -- GitLab