From 61bb9c12ba7b9e09bc0d0898218223cc1f33e959 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 12 Feb 2021 00:48:39 +0100
Subject: [PATCH] Added latent regularization control to
 deepof.data.coordinates.deep_unsupervised_embedding()

---
 deepof/train_model.py |  3 ++-
 deepof/train_utils.py | 23 ++++++++++++++++++-----
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/deepof/train_model.py b/deepof/train_model.py
index 638f21c6..7380650c 100644
--- a/deepof/train_model.py
+++ b/deepof/train_model.py
@@ -287,7 +287,7 @@ project_coords = project(
     animal_ids=tuple([animal_id]),
     arena="circular",
     arena_dims=tuple([arena_dims]),
-    enable_iterative_imputation=True,
+    enable_iterative_imputation=False,
     exclude_bodyparts=exclude_bodyparts,
     exp_conditions=treatment_dict,
     path=train_path,
@@ -359,6 +359,7 @@ if not tune:
 
     trained_models = project_coords.deep_unsupervised_embedding(
         (X_train, y_train, X_val, y_val),
+        epochs=1,
         batch_size=batch_size,
         encoding_size=encoding_size,
         hparams=hparams,
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 11ea712f..4b67d37e 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -94,6 +94,8 @@ def get_callbacks(
     phenotype_class: float,
     predictor: float,
     loss: str,
+    reg_cat_clusters: bool,
+    reg_cluster_variance: bool,
     logparam: dict = None,
     outpath: str = ".",
 ) -> List[Union[Any]]:
@@ -103,6 +105,14 @@ def get_callbacks(
     - cp_callback: for checkpoint saving,
     - onecycle: for learning rate scheduling"""
 
+    latreg = "none"
+    if reg_cat_clusters and not reg_cluster_variance:
+        latreg = "categorical"
+    elif reg_cluster_variance and not reg_cat_clusters:
+        latreg = "variance"
+    elif reg_cat_clusters and reg_cluster_variance:
+        latreg = "categorical+variance"
+
     run_ID = "{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
         ("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
@@ -110,6 +120,7 @@ def get_callbacks(
         ("_loss={}".format(loss) if variational else ""),
         ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
         ("_k={}".format(logparam["k"]) if logparam is not None else ""),
+        ("_latreg={}".format(latreg)),
         (datetime.now().strftime("%Y%m%d-%H%M%S")),
     )
 
@@ -251,11 +262,11 @@ def autoencoder_fitting(
     log_history: bool,
     log_hparams: bool,
     loss: str,
-    mmd_warmup,
-    montecarlo_kl,
-    n_components,
-    output_path,
-    phenotype_class,
+    mmd_warmup: int,
+    montecarlo_kl: int,
+    n_components: int,
+    output_path: str,
+    phenotype_class: float,
     predictor: float,
     pretrained: str,
     save_checkpoints: bool,
@@ -290,6 +301,8 @@ def autoencoder_fitting(
         phenotype_class=phenotype_class,
         predictor=predictor,
         loss=loss,
+        reg_cat_clusters=reg_cluster_variance,
+        reg_cluster_variance=reg_cluster_variance,
         logparam=logparam,
         outpath=output_path,
     )
-- 
GitLab