Commit 31a78518 authored by lucas_miranda's avatar lucas_miranda
Browse files

Prototyped KNN_purity callback

parent 78d8baef
...@@ -144,6 +144,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): ...@@ -144,6 +144,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
start_rate: float = None, start_rate: float = None,
last_iterations: int = None, last_iterations: int = None,
last_rate: float = None, last_rate: float = None,
log_dir: str = ".",
): ):
super().__init__() super().__init__()
self.iterations = iterations self.iterations = iterations
...@@ -154,6 +155,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): ...@@ -154,6 +155,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.last_rate = last_rate or self.start_rate / 1000 self.last_rate = last_rate or self.start_rate / 1000
self.iteration = 0 self.iteration = 0
self.history = {} self.history = {}
self.log_dir = log_dir
def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float: def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1 return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1
...@@ -186,7 +188,15 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): ...@@ -186,7 +188,15 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
self.iteration += 1 self.iteration += 1
K.set_value(self.model.optimizer.lr, rate) K.set_value(self.model.optimizer.lr, rate)
logs["learning_rate"] = self.last_rate def on_epoch_end(self, epoch, logs=None):
"""Logs the learning rate to tensorboard"""
writer = tf.summary.create_file_writer(self.log_dir)
with writer.as_default():
tf.summary.scalar(
"learning_rate", data=self.model.optimizer.lr, step=epoch,
)
class knn_cluster_purity(tf.keras.callbacks.Callback): class knn_cluster_purity(tf.keras.callbacks.Callback):
...@@ -196,13 +206,14 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -196,13 +206,14 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
""" """
def __init__(self, k=5, samples=10000): def __init__(self, validation_data=None, k=100, samples=10000):
super().__init__() super().__init__()
self.validation_data = validation_data
self.k = k self.k = k
self.samples = samples self.samples = samples
# noinspection PyMethodOverriding,PyTypeChecker # noinspection PyMethodOverriding,PyTypeChecker
def on_epoch_end(self, batch: int, logs): def on_epoch_end(self, epoch, logs=None):
""" Passes samples through the encoder and computes cluster purity on the latent embedding """ """ Passes samples through the encoder and computes cluster purity on the latent embedding """
if self.validation_data is not None: if self.validation_data is not None:
...@@ -255,7 +266,9 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): ...@@ -255,7 +266,9 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
* np.max(groups[sample]) * np.max(groups[sample])
) )
logs["knn_cluster_purity"] = purity_vector.mean() tf.summary.scalar(
"knn_cluster_purity", data=purity_vector.mean(), step=epoch,
)
class uncorrelated_features_constraint(Constraint): class uncorrelated_features_constraint(Constraint):
......
...@@ -119,6 +119,7 @@ def get_callbacks( ...@@ -119,6 +119,7 @@ def get_callbacks(
onecycle = deepof.model_utils.one_cycle_scheduler( onecycle = deepof.model_utils.one_cycle_scheduler(
X_train.shape[0] // batch_size * 250, X_train.shape[0] // batch_size * 250,
max_rate=0.005, max_rate=0.005,
log_dir=os.path.join(outpath, "metrics")
) )
callbacks = [run_ID, tensorboard_callback, knn, onecycle] callbacks = [run_ID, tensorboard_callback, knn, onecycle]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment