From 61bb9c12ba7b9e09bc0d0898218223cc1f33e959 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Fri, 12 Feb 2021 00:48:39 +0100 Subject: [PATCH] Added latent regularization control to deepof.data.coordinates.deep_unsupervised_embedding() --- deepof/train_model.py | 3 ++- deepof/train_utils.py | 25 +++++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/deepof/train_model.py b/deepof/train_model.py index 638f21c6..7380650c 100644 --- a/deepof/train_model.py +++ b/deepof/train_model.py @@ -287,7 +287,7 @@ project_coords = project( animal_ids=tuple([animal_id]), arena="circular", arena_dims=tuple([arena_dims]), - enable_iterative_imputation=True, + enable_iterative_imputation=False, exclude_bodyparts=exclude_bodyparts, exp_conditions=treatment_dict, path=train_path, @@ -359,6 +359,7 @@ if not tune: trained_models = project_coords.deep_unsupervised_embedding( (X_train, y_train, X_val, y_val), + epochs=1, batch_size=batch_size, encoding_size=encoding_size, hparams=hparams, diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 11ea712f..4b67d37e 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -94,6 +94,8 @@ def get_callbacks( phenotype_class: float, predictor: float, loss: str, + reg_cat_clusters: bool, + reg_cluster_variance: bool, logparam: dict = None, outpath: str = ".", ) -> List[Union[Any]]: @@ -103,6 +105,15 @@ - cp_callback: for checkpoint saving, - onecycle: for learning rate scheduling""" + latreg = "none" + if reg_cat_clusters and not reg_cluster_variance: + latreg = "categorical" + elif reg_cluster_variance and not reg_cat_clusters: + latreg = "variance" + elif reg_cat_clusters and reg_cluster_variance: + latreg = "categorical+variance" + - run_ID = "{}{}{}{}{}{}_{}".format( + run_ID = "{}{}{}{}{}{}{}_{}".format( ("GMVAE" if variational else "AE"), ("_Pred={}".format(predictor) if predictor > 0 and variational else ""), @@ -110,6 +121,7 @@ ("_loss={}".format(loss) if
variational else ""), ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""), ("_k={}".format(logparam["k"]) if logparam is not None else ""), + ("_latreg={}".format(latreg)), (datetime.now().strftime("%Y%m%d-%H%M%S")), ) @@ -251,11 +262,11 @@ def autoencoder_fitting( log_history: bool, log_hparams: bool, loss: str, - mmd_warmup, - montecarlo_kl, - n_components, - output_path, - phenotype_class, + mmd_warmup: int, + montecarlo_kl: int, + n_components: int, + output_path: str, + phenotype_class: float, predictor: float, pretrained: str, save_checkpoints: bool, @@ -290,6 +301,8 @@ def autoencoder_fitting( phenotype_class=phenotype_class, predictor=predictor, loss=loss, + reg_cat_clusters=reg_cat_clusters, + reg_cluster_variance=reg_cluster_variance, logparam=logparam, outpath=output_path, ) -- GitLab