Commit 9aceb20e authored by lucas_miranda's avatar lucas_miranda
Browse files

Fixed a bug in model_utils.py that yielded nan entropy values when there were...

Fixed a bug in model_utils.py that yielded nan entropy values when there were no neighbors in the radius and the selected cluster was 0
parent de503fb3
......@@ -871,6 +871,7 @@ class coordinates:
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
entropy_samples: int = 10000,
entropy_min_n:int = 5,
) -> Tuple:
"""
Annotates coordinates using an unsupervised autoencoder.
......@@ -932,6 +933,7 @@ class coordinates:
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
entropy_samples=entropy_samples,
entropy_min_n=entropy_min_n,
)
# returns a list of trained tensorflow models
......
......@@ -108,6 +108,13 @@ parser.add_argument(
default=10,
type=int,
)
parser.add_argument(
"--entropy-min-n",
"-entminn",
help="Minimum number of neighbors in radius to take a sample into account when computing entropy",
default=5,
type=int,
)
parser.add_argument(
"--entropy-samples",
"-ents",
......@@ -233,6 +240,7 @@ hparams = args.hyperparameters if args.hyperparameters is not None else {}
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
entropy_min_n = args.entropy_min_n
entropy_samples = args.entropy_samples
latent_reg = args.latent_reg
loss = args.loss
......@@ -376,6 +384,7 @@ if not tune:
reg_cat_clusters=("categorical" in latent_reg),
reg_cluster_variance=("variance" in latent_reg),
entropy_samples=entropy_samples,
entropy_min_n=entropy_min_n,
)
else:
......@@ -390,6 +399,7 @@ else:
cp=False,
variational=variational,
entropy_samples=entropy_samples,
entropy_min_n=entropy_min_n,
phenotype_class=pheno_class,
predictor=predictor,
loss=loss,
......
......@@ -76,6 +76,7 @@ def get_callbacks(
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
entropy_samples: int = 15000,
entropy_min_n: int = 5,
logparam: dict = None,
outpath: str = ".",
) -> List[Union[Any]]:
......@@ -117,6 +118,7 @@ def get_callbacks(
validation_data=X_val,
log_dir=os.path.join(outpath, "metrics", run_ID),
variational=variational,
min_n=entropy_min_n,
)
onecycle = deepof.model_utils.one_cycle_scheduler(
......@@ -264,6 +266,7 @@ def autoencoder_fitting(
reg_cat_clusters: bool,
reg_cluster_variance: bool,
entropy_samples: int,
entropy_min_n: int,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
......@@ -293,6 +296,7 @@ def autoencoder_fitting(
predictor=predictor,
loss=loss,
entropy_samples=entropy_samples,
entropy_min_n=entropy_min_n,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
logparam=logparam,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment