Commit 9f0a9f31 authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 31a78518
Pipeline #95526 canceled with stages
in 25 minutes and 22 seconds
......@@ -195,7 +195,9 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
with writer.as_default():
tf.summary.scalar(
"learning_rate", data=self.model.optimizer.lr, step=epoch,
"learning_rate",
data=self.model.optimizer.lr,
step=epoch,
)
......@@ -206,11 +208,12 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def __init__(self, validation_data=None, k=100, samples=10000):
def __init__(self, validation_data=None, k=100, samples=10000, log_dir="."):
super().__init__()
self.validation_data = validation_data
self.k = k
self.samples = samples
self.log_dir = log_dir
# noinspection PyMethodOverriding,PyTypeChecker
def on_epoch_end(self, epoch, logs=None):
......@@ -266,8 +269,12 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
* np.max(groups[sample])
)
writer = tf.summary.create_file_writer(self.log_dir)
with writer.as_default():
tf.summary.scalar(
"knn_cluster_purity", data=purity_vector.mean(), step=epoch,
"knn_cluster_purity",
data=purity_vector.mean(),
step=epoch,
)
......
......@@ -119,7 +119,7 @@ def get_callbacks(
onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250,
max_rate=0.005,
log_dir=os.path.join(outpath, "metrics")
log_dir=os.path.join(outpath, "metrics"),
)
callbacks = [run_ID, tensorboard_callback, knn, onecycle]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment