diff --git a/main.ipynb b/main.ipynb index a10e64e4e926ac28b17e1cf0463b0c69aeaa46ef..6c134313d9939ed9f10f10c82eb2b808fcd738a2 100644 --- a/main.ipynb +++ b/main.ipynb @@ -303,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "NAME = 'Baseline_AE_BatchNorm'\n", + "NAME = 'Baseline_VAE_BatchNorm'\n", "log_dir = os.path.abspath(\n", " \"logs/fit/{}_{}\".format(NAME, datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n", ")\n", @@ -325,7 +325,7 @@ "metadata": {}, "outputs": [], "source": [ - "encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()" + "#encoder, decoder, ae = SEQ_2_SEQ_AE(pttest.shape).build()" ] }, { @@ -334,7 +334,7 @@ "metadata": {}, "outputs": [], "source": [ - "#encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()" + "encoder, generator, vae = SEQ_2_SEQ_VAE(pttest.shape).build()" ] }, { @@ -376,7 +376,7 @@ "outputs": [], "source": [ "tf.config.experimental_run_functions_eagerly(False)\n", - "history = ae.fit(pttrain, pttrain, epochs=50, batch_size=256, verbose=1, validation_data=(pttest, pttest),\n", + "history = vae.fit(pttrain, pttrain, epochs=50, batch_size=256, verbose=1, validation_data=(pttest, pttest),\n", " callbacks=[tensorboard_callback])" ] }, diff --git a/source/models.py b/source/models.py index fd46fda717866636b068b4280a275215556c7c37..9d0b97d1b774acc931d7e92fa02793f459b86178 100644 --- a/source/models.py +++ b/source/models.py @@ -304,8 +304,7 @@ class SEQ_2_SEQ_MMVAE: # - Change LSTMs for GRU (done!) # - Tied/Untied weights (done!) # - orthogonal/non-orthogonal weights (done!) -# - Unit Norm constraint - +# - Unit Norm constraint (done!) # - add batch normalization # - add He initialization