Commit 3393dcd1 authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 9219111b
......@@ -870,6 +870,8 @@ class coordinates:
variational: bool = True,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
knn_neighbors: int = 100,
knn_samples: int = 10000,
) -> Tuple:
"""
Annotates coordinates using an unsupervised autoencoder.
......@@ -930,6 +932,8 @@ class coordinates:
variational=variational,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
knn_neighbors=knn_neighbors,
knn_samples=knn_samples,
)
# returns a list of trained tensorflow models
......
......@@ -108,6 +108,20 @@ parser.add_argument(
default=10,
type=int,
)
parser.add_argument(
"--knn-neighbors",
"-knn",
help="Neighbors to take into account to compute KNN cluster purity",
default=100,
type=int,
)
parser.add_argument(
"--knn-samples",
"-knns",
help="Samples to use to compute KNN cluster purity",
default=10000,
type=int,
)
parser.add_argument(
"--latent-reg",
"-lreg",
......@@ -226,6 +240,8 @@ hparams = args.hyperparameters if args.hyperparameters is not None else {}
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
knn_neighbors = args.knn_neighbors
knn_samples = args.knn_samples
latent_reg = args.latent_reg
loss = args.loss
mmd_wu = args.mmd_warmup
......@@ -367,6 +383,8 @@ if not tune:
variational=variational,
reg_cat_clusters=("categorical" in latent_reg),
reg_cluster_variance=("variance" in latent_reg),
knn_neighbors=knn_neighbors,
knn_samples=knn_samples,
)
else:
......@@ -374,11 +392,13 @@ else:
hyp = "S2SGMVAE" if variational else "S2SAE"
run_ID, tensorboard_callback, onecycle = get_callbacks(
run_ID, tensorboard_callback, knn, onecycle = get_callbacks(
X_train=X_train,
batch_size=batch_size,
cp=False,
variational=variational,
knn_samples=knn_samples,
knn_neighbors=knn_neighbors,
phenotype_class=pheno_class,
predictor=predictor,
loss=loss,
......@@ -403,6 +423,7 @@ else:
callbacks=[
tensorboard_callback,
onecycle,
knn,
CustomStopper(
monitor="val_loss",
patience=5,
......
......@@ -74,6 +74,8 @@ def get_callbacks(
cp: bool = False,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
knn_samples: int = 10000,
knn_neighbors: int = 100,
logparam: dict = None,
outpath: str = ".",
) -> List[Union[Any]]:
......@@ -109,12 +111,17 @@ def get_callbacks(
profile_batch=2,
)
knn = deepof.model_utils.knn_cluster_purity(
k=knn_neighbors,
samples=knn_samples,
)
onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250,
max_rate=0.005,
)
callbacks = [run_ID, tensorboard_callback, onecycle]
callbacks = [run_ID, tensorboard_callback, knn, onecycle]
if cp:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
......@@ -252,6 +259,8 @@ def autoencoder_fitting(
variational: bool,
reg_cat_clusters: bool,
reg_cluster_variance: bool,
knn_neighbors: int,
knn_samples: int,
):
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
......@@ -279,6 +288,8 @@ def autoencoder_fitting(
phenotype_class=phenotype_class,
predictor=predictor,
loss=loss,
knn_neighbors=knn_neighbors,
knn_samples=knn_samples,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
logparam=logparam,
......
......@@ -56,7 +56,7 @@ def test_get_callbacks(
pheno_class,
loss,
):
runID, tbc, cycle1c, cpc = deepof.train_utils.get_callbacks(
runID, tbc, knn, cycle1c, cpc = deepof.train_utils.get_callbacks(
X_train,
batch_size,
variational,
......@@ -71,6 +71,7 @@ def test_get_callbacks(
assert type(runID) == str
assert type(tbc) == tf.keras.callbacks.TensorBoard
assert type(cpc) == tf.keras.callbacks.ModelCheckpoint
assert type(knn) == deepof.model_utils.knn_cluster_purity
assert type(cycle1c) == deepof.model_utils.one_cycle_scheduler
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment