Commit 61bb9c12 authored by Lucas Miranda

Added latent regularization control to deepof.data.coordinates.deep_unsupervised_embedding()

parent c7bb409c
Pipeline #93312 failed
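
This commit exposes two latent-space regularization switches, reg_cat_clusters and reg_cluster_variance, through deepof.data.coordinates.deep_unsupervised_embedding(). A minimal usage sketch follows; it is not taken from the commit itself: the project object, data splits, and literal values are assumptions, and only the two reg_* keyword names come from the diff below.

trained_models = project_coords.deep_unsupervised_embedding(
    (X_train, y_train, X_val, y_val),  # preprocessed train/validation splits
    epochs=1,
    batch_size=256,              # assumed value, for illustration only
    encoding_size=6,             # assumed value, for illustration only
    reg_cat_clusters=True,       # regularize the categorical cluster assignments
    reg_cluster_variance=False,  # leave per-cluster variance unregularized
)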
@@ -287,7 +287,7 @@ project_coords = project(
     animal_ids=tuple([animal_id]),
     arena="circular",
     arena_dims=tuple([arena_dims]),
-    enable_iterative_imputation=True,
+    enable_iterative_imputation=False,
     exclude_bodyparts=exclude_bodyparts,
     exp_conditions=treatment_dict,
     path=train_path,
@@ -359,6 +359,7 @@ if not tune:
     trained_models = project_coords.deep_unsupervised_embedding(
         (X_train, y_train, X_val, y_val),
         epochs=1,
+        batch_size=batch_size,
         encoding_size=encoding_size,
         hparams=hparams,
@@ -94,6 +94,8 @@ def get_callbacks(
     phenotype_class: float,
     predictor: float,
     loss: str,
+    reg_cat_clusters: bool,
+    reg_cluster_variance: bool,
     logparam: dict = None,
     outpath: str = ".",
 ) -> List[Union[Any]]:
@@ -103,6 +105,14 @@
     - cp_callback: for checkpoint saving,
     - onecycle: for learning rate scheduling"""

+    latreg = "none"
+    if reg_cat_clusters and not reg_cluster_variance:
+        latreg = "categorical"
+    elif reg_cluster_variance and not reg_cat_clusters:
+        latreg = "variance"
+    elif reg_cat_clusters and reg_cluster_variance:
+        latreg = "categorical+variance"
+
     run_ID = "{}{}{}{}{}{}_{}".format(
         ("GMVAE" if variational else "AE"),
         ("_Pred={}".format(predictor) if predictor > 0 and variational else ""),
@@ -110,6 +120,7 @@
         ("_loss={}".format(loss) if variational else ""),
         ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
         ("_k={}".format(logparam["k"]) if logparam is not None else ""),
+        ("_latreg={}".format(latreg)),
         (datetime.now().strftime("%Y%m%d-%H%M%S")),
     )
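
The block added above tags each training run with the latent regularization mode, so runs that differ only in these flags get distinct identifiers. A self-contained sketch of the flag-to-label mapping and the resulting run name (the name format here is simplified and only illustrative; the real one also encodes predictor, loss, encoding size and cluster number, as shown in the diff):

from datetime import datetime

def build_run_id(reg_cat_clusters: bool, reg_cluster_variance: bool) -> str:
    # Same flag-to-label mapping as the latreg block added in this commit.
    latreg = "none"
    if reg_cat_clusters and not reg_cluster_variance:
        latreg = "categorical"
    elif reg_cluster_variance and not reg_cat_clusters:
        latreg = "variance"
    elif reg_cat_clusters and reg_cluster_variance:
        latreg = "categorical+variance"
    # Simplified run name; only the latreg tag and timestamp are shown here.
    return "GMVAE_latreg={}_{}".format(latreg, datetime.now().strftime("%Y%m%d-%H%M%S"))

print(build_run_id(True, True))  # e.g. GMVAE_latreg=categorical+variance_20210101-120000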
@@ -251,11 +262,11 @@ def autoencoder_fitting(
     log_history: bool,
     log_hparams: bool,
     loss: str,
-    mmd_warmup,
-    montecarlo_kl,
-    n_components,
-    output_path,
-    phenotype_class,
+    mmd_warmup: int,
+    montecarlo_kl: int,
+    n_components: int,
+    output_path: str,
+    phenotype_class: float,
     predictor: float,
     pretrained: str,
     save_checkpoints: bool,
@@ -290,6 +301,8 @@
         phenotype_class=phenotype_class,
         predictor=predictor,
         loss=loss,
+        reg_cat_clusters=reg_cat_clusters,
+        reg_cluster_variance=reg_cluster_variance,
         logparam=logparam,
         outpath=output_path,
     )
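
With the flags now forwarded from autoencoder_fitting() to get_callbacks(), the run identifier distinguishes checkpoints and logs of runs that differ only in latent regularization. A rough, hypothetical sketch of how such an identifier can key a checkpoint callback (named cp_callback in the docstring above) and a TensorBoard logger; this is not the package's own implementation, and the paths and argument choices are assumptions:

import os
import tensorflow as tf

run_ID = "GMVAE_latreg=categorical+variance_20210101-120000"  # hypothetical run name
outpath = "."

# Embedding run_ID in the paths keeps runs with different regularization
# settings from overwriting each other's checkpoints and logs.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(outpath, "checkpoints", run_ID, "cp-{epoch:04d}.ckpt"),
    save_weights_only=True,
    save_freq="epoch",
)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join(outpath, "fit", run_ID),
    profile_batch=0,
)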