Commit 3c773f9f authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 9f0a9f31
......@@ -15,6 +15,7 @@ from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
......
......@@ -394,6 +394,7 @@ else:
run_ID, tensorboard_callback, knn, onecycle = get_callbacks(
X_train=X_train,
X_val=(X_val if X_val.shape != (0,) else None),
batch_size=batch_size,
cp=False,
variational=variational,
......
......@@ -71,6 +71,7 @@ def get_callbacks(
phenotype_class: float,
predictor: float,
loss: str,
X_val: np.array = None,
cp: bool = False,
reg_cat_clusters: bool = False,
reg_cluster_variance: bool = False,
......@@ -114,6 +115,7 @@ def get_callbacks(
knn = deepof.model_utils.knn_cluster_purity(
k=knn_neighbors,
samples=knn_samples,
validation_data=X_val,
)
onecycle = deepof.model_utils.one_cycle_scheduler(
......@@ -283,6 +285,7 @@ def autoencoder_fitting(
# Load callbacks
run_ID, *cbacks = get_callbacks(
X_train=X_train,
X_val=(X_val if X_val.shape != (0,) else None),
batch_size=batch_size,
cp=save_checkpoints,
variational=variational,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment