Commit 31a78518 authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 78d8baef
......@@ -144,6 +144,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
start_rate: float = None,
last_iterations: int = None,
last_rate: float = None,
log_dir: str = ".",
):
super().__init__()
self.iterations = iterations
......@@ -154,6 +155,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.last_rate = last_rate or self.start_rate / 1000
self.iteration = 0
self.history = {}
self.log_dir = log_dir
def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1
......@@ -186,7 +188,15 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.iteration += 1
K.set_value(self.model.optimizer.lr, rate)
logs["learning_rate"] = self.last_rate
def on_epoch_end(self, epoch, logs=None):
"""Logs the learning rate to tensorboard"""
writer = tf.summary.create_file_writer(self.log_dir)
with writer.as_default():
tf.summary.scalar(
"learning_rate", data=self.model.optimizer.lr, step=epoch,
)
class knn_cluster_purity(tf.keras.callbacks.Callback):
......@@ -196,13 +206,14 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
"""
def __init__(self, k=5, samples=10000):
def __init__(self, validation_data=None, k=100, samples=10000):
super().__init__()
self.validation_data = validation_data
self.k = k
self.samples = samples
# noinspection PyMethodOverriding,PyTypeChecker
def on_epoch_end(self, batch: int, logs):
def on_epoch_end(self, epoch, logs=None):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """
if self.validation_data is not None:
......@@ -255,7 +266,9 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
* np.max(groups[sample])
)
logs["knn_cluster_purity"] = purity_vector.mean()
tf.summary.scalar(
"knn_cluster_purity", data=purity_vector.mean(), step=epoch,
)
class uncorrelated_features_constraint(Constraint):
......
......@@ -119,6 +119,7 @@ def get_callbacks(
onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250,
max_rate=0.005,
log_dir=os.path.join(outpath, "metrics")
)
callbacks = [run_ID, tensorboard_callback, knn, onecycle]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment