diff --git a/deepof/train_model.py b/deepof/train_model.py
index 4975c33ee8e5a9020ee1aa70d338359f6fd6c99a..c1ff5acad38760d00e325008397798e9c61c70b0 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -490,7 +490,7 @@ if not tune:
             history = gmvaep.fit(
                 x=Xs,
                 y=ys,
-                epochs=35,
+                epochs=1,
                 batch_size=batch_size,
                 verbose=1,
                 validation_data=(
@@ -510,25 +510,24 @@ if not tune:
                 )
             )
 
-            if logparam is not None:
-                # Logparams to tensorboard
-                def run(run_dir, hpms):
-                    with tf.summary.create_file_writer(run_dir).as_default():
-                        hp.hparams(hpms)  # record the values used in this trial
-                        val_mae = tf.reduce_mean(
-                            tf.keras.metrics.mean_absolute_error(
-                                X_val, gmvaep.predict(X_val)
-                            )
+            # Logparams to tensorboard
+            def run(run_dir, hpms):
+                with tf.summary.create_file_writer(run_dir).as_default():
+                    hp.hparams(hpms)  # record the values used in this trial
+                    val_mae = tf.reduce_mean(
+                        tf.keras.metrics.mean_absolute_error(
+                            X_val, gmvaep.predict(X_val)
                         )
-                        val_mse = tf.reduce_mean(
-                            tf.keras.metrics.mean_squared_error(
-                                X_val, gmvaep.predict(X_val)
-                            )
+                    )
+                    val_mse = tf.reduce_mean(
+                        tf.keras.metrics.mean_squared_error(
+                            X_val, gmvaep.predict(X_val)
                         )
-                        tf.summary.scalar("val_mae", val_mae, step=1)
-                        tf.summary.scalar("val_mse", val_mse, step=1)
+                    )
+                    tf.summary.scalar("val_mae", val_mae, step=1)
+                    tf.summary.scalar("val_mse", val_mse, step=1)
 
-                run(os.path.join(output_path, "hparams", run_ID), logparam)
+            run(os.path.join(output_path, "hparams", run_ID), logparam)
 
         # To avoid stability issues
         tf.keras.backend.clear_session()
diff --git a/deepof_experiments.smk b/deepof_experiments.smk
index 80c5a744de3b55ca2246d8e878aa34dfe27f7651..a6aca1538ff76ab6cd9bf77cc93a6174d575aa28 100644
--- a/deepof_experiments.smk
+++ b/deepof_experiments.smk
@@ -73,8 +73,8 @@ rule explore_encoding_dimension_and_loss_function:
         "--predictor 0 "
         "--variational True "
         "--loss {wildcards.loss} "
-        "--kl-warmup 2 "
-        "--mmd-warmup 2 "
+        "--kl-warmup 20 "
+        "--mmd-warmup 20 "
         "--montecarlo-kl 10 "
         "--encoding-size {wildcards.encs} "
         "--batch-size 256 "