Commit 6677665a authored by lucas_miranda
Browse files

Modified cluster purity computation. Instead of KNN, we now look at...

Modified cluster purity computation. Instead of KNN, we now look at neighborhoods of a predefined radius
parent 0b91c7ca
Pipeline #95653 canceled with stages
in 9 minutes and 39 seconds
......@@ -109,16 +109,16 @@ parser.add_argument(
type=int,
)
parser.add_argument(
"--knn-neighbors",
"-knn",
help="Neighbors to take into account to compute KNN cluster purity",
"--entropy-radius",
"-entr",
help="radius of the neighborhood used to compute cluster purity",
default=100,
type=int,
)
parser.add_argument(
"--knn-samples",
"-knns",
help="Samples to use to compute KNN cluster purity",
"--entropy-samples",
"-ents",
help="Samples to use to compute cluster purity",
default=10000,
type=int,
)
......@@ -240,8 +240,8 @@ hparams = args.hyperparameters if args.hyperparameters is not None else {}
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
knn_neighbors = args.knn_neighbors
knn_samples = args.knn_samples
entropy_radius = args.entropy_radius
entropy_samples = args.entropy_samples
latent_reg = args.latent_reg
loss = args.loss
mmd_wu = args.mmd_warmup
......@@ -383,8 +383,8 @@ if not tune:
variational=variational,
reg_cat_clusters=("categorical" in latent_reg),
reg_cluster_variance=("variance" in latent_reg),
knn_neighbors=knn_neighbors,
knn_samples=knn_samples,
entropy_radius=entropy_radius,
entropy_samples=entropy_samples,
)
else:
......@@ -392,14 +392,14 @@ else:
hyp = "S2SGMVAE" if variational else "S2SAE"
run_ID, tensorboard_callback, knn, onecycle = get_callbacks(
run_ID, tensorboard_callback, entropy, onecycle = get_callbacks(
X_train=X_train,
X_val=(X_val if X_val.shape != (0,) else None),
batch_size=batch_size,
cp=False,
variational=variational,
knn_samples=knn_samples,
knn_neighbors=knn_neighbors,
entropy_samples=entropy_samples,
entropy_radius=entropy_radius,
phenotype_class=pheno_class,
predictor=predictor,
loss=loss,
......@@ -424,7 +424,7 @@ else:
callbacks=[
tensorboard_callback,
onecycle,
knn,
entropy,
CustomStopper(
monitor="val_loss",
patience=5,
......
......@@ -75,8 +75,8 @@ def get_callbacks(
cp: bool = False,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
knn_samples: int = 10000,
knn_neighbors: int = 100,
entropy_samples: int = 10000,
entropy_radius: float = 0.75,
logparam: dict = None,
outpath: str = ".",
) -> List[Union[Any]]:
......@@ -112,9 +112,9 @@ def get_callbacks(
profile_batch=2,
)
knn = deepof.model_utils.neighbor_cluster_purity(
k=knn_neighbors,
samples=knn_samples,
entropy = deepof.model_utils.neighbor_cluster_purity(
r=entropy_radius,
samples=entropy_samples,
validation_data=X_val,
log_dir=os.path.join(outpath, "metrics", run_ID),
variational=variational,
......@@ -126,7 +126,7 @@ def get_callbacks(
log_dir=os.path.join(outpath, "metrics", run_ID),
)
callbacks = [run_ID, tensorboard_callback, knn, onecycle]
callbacks = [run_ID, tensorboard_callback, entropy, onecycle]
if cp:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
......@@ -264,8 +264,8 @@ def autoencoder_fitting(
variational: bool,
reg_cat_clusters: bool,
reg_cluster_variance: bool,
knn_neighbors: int,
knn_samples: int,
entropy_radius: float,
entropy_samples: int,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
......@@ -294,8 +294,8 @@ def autoencoder_fitting(
phenotype_class=phenotype_class,
predictor=predictor,
loss=loss,
knn_neighbors=knn_neighbors,
knn_samples=knn_samples,
entropy_radius=entropy_radius,
entropy_samples=entropy_samples,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
logparam=logparam,
......
......@@ -56,7 +56,7 @@ def test_get_callbacks(
pheno_class,
loss,
):
runID, tbc, knn, cycle1c, cpc = deepof.train_utils.get_callbacks(
runID, tbc, entropy, cycle1c, cpc = deepof.train_utils.get_callbacks(
X_train,
batch_size,
variational,
......@@ -71,7 +71,7 @@ def test_get_callbacks(
assert type(runID) == str
assert type(tbc) == tf.keras.callbacks.TensorBoard
assert type(cpc) == tf.keras.callbacks.ModelCheckpoint
assert type(knn) == deepof.model_utils.neighbor_cluster_purity
assert type(entropy) == deepof.model_utils.neighbor_cluster_purity
assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment