From c1b69e4862f4e7241bbcd6b71dc56b9f0c1e48c7 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Mon, 7 Dec 2020 23:47:39 +0100
Subject: [PATCH] Added CustomStopper class to train_utils.py, to start early
 stopping only after annealing is over

---
 deepof/train_model.py  | 33 ++++++++++++++++-----------------
 deepof_experiments.smk |  4 ++--
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/deepof/train_model.py b/deepof/train_model.py
index 4975c33e..c1ff5aca 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -490,7 +490,7 @@ if not tune:
             history = gmvaep.fit(
                 x=Xs,
                 y=ys,
-                epochs=35,
+                epochs=1,
                 batch_size=batch_size,
                 verbose=1,
                 validation_data=(
@@ -510,25 +510,24 @@ if not tune:
                 )
             )
 
-            if logparam is not None:
-                # Logparams to tensorboard
-                def run(run_dir, hpms):
-                    with tf.summary.create_file_writer(run_dir).as_default():
-                        hp.hparams(hpms)  # record the values used in this trial
-                        val_mae = tf.reduce_mean(
-                            tf.keras.metrics.mean_absolute_error(
-                                X_val, gmvaep.predict(X_val)
-                            )
+            # Logparams to tensorboard
+            def run(run_dir, hpms):
+                with tf.summary.create_file_writer(run_dir).as_default():
+                    hp.hparams(hpms)  # record the values used in this trial
+                    val_mae = tf.reduce_mean(
+                        tf.keras.metrics.mean_absolute_error(
+                            X_val, gmvaep.predict(X_val)
                         )
-                        val_mse = tf.reduce_mean(
-                            tf.keras.metrics.mean_squared_error(
-                                X_val, gmvaep.predict(X_val)
-                            )
+                    )
+                    val_mse = tf.reduce_mean(
+                        tf.keras.metrics.mean_squared_error(
+                            X_val, gmvaep.predict(X_val)
                         )
-                        tf.summary.scalar("val_mae", val_mae, step=1)
-                        tf.summary.scalar("val_mse", val_mse, step=1)
+                    )
+                    tf.summary.scalar("val_mae", val_mae, step=1)
+                    tf.summary.scalar("val_mse", val_mse, step=1)
 
-                run(os.path.join(output_path, "hparams", run_ID), logparam)
+            run(os.path.join(output_path, "hparams", run_ID), logparam)
 
         # To avoid stability issues
         tf.keras.backend.clear_session()
diff --git a/deepof_experiments.smk b/deepof_experiments.smk
index 80c5a744..a6aca153 100644
--- a/deepof_experiments.smk
+++ b/deepof_experiments.smk
@@ -73,8 +73,8 @@ rule explore_encoding_dimension_and_loss_function:
         "--predictor 0 "
         "--variational True "
         "--loss {wildcards.loss} "
-        "--kl-warmup 2 "
-        "--mmd-warmup 2 "
+        "--kl-warmup 20 "
+        "--mmd-warmup 20 "
         "--montecarlo-kl 10 "
         "--encoding-size {wildcards.encs} "
         "--batch-size 256 "
-- 
GitLab