Commit 8017c6d2 authored by lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent 4e930fb5
Pipeline #98488 canceled with stages
in 25 seconds
......@@ -903,6 +903,7 @@ class coordinates:
reg_cluster_variance: bool = False,
entropy_samples: int = 10000,
entropy_knn: int = 100,
input_type: str = False,
run: int = 0,
) -> Tuple:
"""
......@@ -967,6 +968,7 @@ class coordinates:
reg_cluster_variance=reg_cluster_variance,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
input_type=input_type,
run=run,
)
......
......@@ -403,6 +403,7 @@ if not tune:
reg_cluster_variance=("variance" in latent_reg),
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
input_type=input_type,
run=run,
)
......@@ -413,16 +414,17 @@ else:
run_ID, tensorboard_callback, entropy, onecycle = get_callbacks(
X_train=X_train,
X_val=(X_val if X_val.shape != (0,) else None),
batch_size=batch_size,
cp=False,
variational=variational,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
rule_based_prediction=rule_based_prediction,
loss=loss,
X_val=(X_val if X_val.shape != (0,) else None),
input_type=input_type,
cp=False,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
logparam=logparam,
outpath=output_path,
run=run,
......
......@@ -74,6 +74,7 @@ def get_callbacks(
rule_based_prediction: float,
loss: str,
X_val: np.array = None,
input_type: str = False,
cp: bool = False,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
......@@ -86,8 +87,10 @@ def get_callbacks(
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
- tensorboard_callback: for real-time visualization;
- cp_callback: for checkpoint saving,
- onecycle: for learning rate scheduling"""
- cp_callback: for checkpoint saving;
- onecycle: for learning rate scheduling;
- entropy: neighborhood entropy in the latent space;
"""
latreg = "none"
if reg_cat_clusters and not reg_cluster_variance:
......@@ -99,6 +102,7 @@ def get_callbacks(
run_ID = "{}{}{}{}{}{}{}_{}".format(
("GMVAE" if variational else "AE"),
("_input_type={}".format(input_type) if input_type else "coords"),
("_NextSeqPred={}".format(next_sequence_prediction) if variational else ""),
("_PhenoPred={}".format(phenotype_prediction) if variational else ""),
("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""),
......@@ -293,6 +297,7 @@ def autoencoder_fitting(
reg_cluster_variance: bool,
entropy_samples: int,
entropy_knn: int,
input_type: str,
run: int = 0,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
......@@ -315,18 +320,19 @@ def autoencoder_fitting(
# Load callbacks
run_ID, *cbacks = get_callbacks(
X_train=X_train,
X_val=(X_val if X_val.shape != (0,) else None),
batch_size=batch_size,
cp=save_checkpoints,
variational=variational,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
rule_based_prediction=rule_based_prediction,
loss=loss,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
input_type=input_type,
X_val=(X_val if X_val.shape != (0,) else None),
cp=save_checkpoints,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
entropy_samples=entropy_samples,
entropy_knn=entropy_knn,
logparam=logparam,
outpath=output_path,
run=run,
......
......@@ -15,7 +15,7 @@ import os
outpath = "/psycl/g/mpsstatgen/lucas/DLC/DLC_autoencoders/DeepOF/deepof/logs/"
losses = ["ELBO"] #, "MMD", "ELBO+MMD"]
losses = ["ELBO"] # , "MMD", "ELBO+MMD"]
encodings = [6] # [2, 4, 6, 8, 10, 12, 14, 16]
cluster_numbers = [25] # [1, 5, 10, 15, 20, 25]
latent_reg = ["none"] # ["none", "categorical", "variance", "categorical+variance"]
......
......@@ -60,11 +60,12 @@ def test_get_callbacks(
X_train=X_train,
batch_size=batch_size,
variational=variational,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
rule_based_prediction=rule_based_prediction,
loss=loss,
X_val=X_train,
input_type=False,
cp=True,
reg_cat_clusters=False,
reg_cluster_variance=False,
......@@ -179,11 +180,12 @@ def test_tune_search(
X_train=X_train,
batch_size=batch_size,
variational=(hypermodel == "S2SGMVAE"),
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
next_sequence_prediction=next_sequence_prediction,
rule_based_prediction=rule_based_prediction,
loss=loss,
X_val=X_train,
input_type=False,
cp=False,
reg_cat_clusters=True,
reg_cluster_variance=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment