From 8af352bed0cf0afd3d364dd7a68dd6385e4a9e02 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Thu, 25 Feb 2021 14:13:40 +0100
Subject: [PATCH] Added latent regularization control to
 deepof.data.coordinates.deep_unsupervised_embedding()

---
 deepof/train_model.py |  2 +-
 deepof/train_utils.py | 10 +++++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/deepof/train_model.py b/deepof/train_model.py
index 7702c1d5..12210272 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -411,7 +411,7 @@ else:
                 start_epoch=max(kl_wu, mmd_wu),
             ),
         ],
-        n_replicas=3,
+        n_replicas=1,
         n_epochs=30,
         outpath=output_path,
     )
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 0aab06b7..081911b3 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -545,7 +545,9 @@ def tune_search(
 
     if hpt_type == "hyperband":
         tuner = Hyperband(
-            directory=os.path.join(outpath, "HyperBandx_{}_{}".format(loss, str(date.today()))),
+            directory=os.path.join(
+                outpath, "HyperBandx_{}_{}".format(loss, str(date.today()))
+            ),
             max_epochs=35,
             hyperband_iterations=hypertun_trials,
             factor=2,
@@ -553,7 +555,9 @@ def tune_search(
         )
     else:
         tuner = BayesianOptimization(
-            directory=os.path.join(outpath, "BayOpt_{}_{}".format(loss, str(date.today()))),
+            directory=os.path.join(
+                outpath, "BayOpt_{}_{}".format(loss, str(date.today()))
+            ),
             max_trials=hypertun_trials,
             **hpt_params
         )
@@ -577,7 +581,7 @@ def tune_search(
         epochs=n_epochs,
         validation_data=(Xvals, yvals),
         verbose=1,
-        batch_size=32,
+        batch_size=64,
         callbacks=callbacks,
     )
 
-- 
GitLab