Skip to content
Snippets Groups Projects
Commit 9f0a9f31 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Prototyped KNN_purity callback

parent 31a78518
No related branches found
No related tags found
No related merge requests found
Pipeline #95526 canceled
...@@ -195,7 +195,9 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): ...@@ -195,7 +195,9 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
with writer.as_default(): with writer.as_default():
tf.summary.scalar( tf.summary.scalar(
"learning_rate", data=self.model.optimizer.lr, step=epoch, "learning_rate",
data=self.model.optimizer.lr,
step=epoch,
) )
...@@ -206,11 +208,12 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -206,11 +208,12 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
""" """
def __init__(self, validation_data=None, k=100, samples=10000): def __init__(self, validation_data=None, k=100, samples=10000, log_dir="."):
super().__init__() super().__init__()
self.validation_data = validation_data self.validation_data = validation_data
self.k = k self.k = k
self.samples = samples self.samples = samples
self.log_dir = log_dir
# noinspection PyMethodOverriding,PyTypeChecker # noinspection PyMethodOverriding,PyTypeChecker
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
...@@ -266,9 +269,13 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -266,9 +269,13 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
* np.max(groups[sample]) * np.max(groups[sample])
) )
tf.summary.scalar( writer = tf.summary.create_file_writer(self.log_dir)
"knn_cluster_purity", data=purity_vector.mean(), step=epoch, with writer.as_default():
) tf.summary.scalar(
"knn_cluster_purity",
data=purity_vector.mean(),
step=epoch,
)
class uncorrelated_features_constraint(Constraint): class uncorrelated_features_constraint(Constraint):
......
...@@ -119,7 +119,7 @@ def get_callbacks( ...@@ -119,7 +119,7 @@ def get_callbacks(
onecycle = deepof.model_utils.one_cycle_scheduler( onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250, X_train.shape[0] // batch_size * 250,
max_rate=0.005, max_rate=0.005,
log_dir=os.path.join(outpath, "metrics") log_dir=os.path.join(outpath, "metrics"),
) )
callbacks = [run_ID, tensorboard_callback, knn, onecycle] callbacks = [run_ID, tensorboard_callback, knn, onecycle]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment