diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 9c895bfbee0603d25eb9a85269ae76bc53c23694..2b7b0a396caa0432572d3fe0dabacdcc2f0660c4 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -144,6 +144,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
         start_rate: float = None,
         last_iterations: int = None,
         last_rate: float = None,
+        log_dir: str = ".",
     ):
         super().__init__()
         self.iterations = iterations
@@ -154,6 +155,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
         self.last_rate = last_rate or self.start_rate / 1000
         self.iteration = 0
         self.history = {}
+        self.log_dir = log_dir
 
     def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
         return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1
@@ -186,7 +188,15 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
         self.iteration += 1
         K.set_value(self.model.optimizer.lr, rate)
 
-        logs["learning_rate"] = self.last_rate
+    def on_epoch_end(self, epoch, logs=None):
+        """Logs the learning rate to tensorboard"""
+
+        writer = tf.summary.create_file_writer(self.log_dir)
+
+        with writer.as_default():
+            tf.summary.scalar(
+                "learning_rate", data=self.model.optimizer.lr, step=epoch,
+            )
 
 
 class knn_cluster_purity(tf.keras.callbacks.Callback):
@@ -196,13 +206,14 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
 
     """
 
-    def __init__(self, k=5, samples=10000):
+    def __init__(self, validation_data=None, k=100, samples=10000):
         super().__init__()
+        self.validation_data = validation_data
         self.k = k
         self.samples = samples
 
     # noinspection PyMethodOverriding,PyTypeChecker
-    def on_epoch_end(self, batch: int, logs):
+    def on_epoch_end(self, epoch, logs=None):
         """ Passes samples through the encoder and computes cluster purity on the latent embedding """
 
         if self.validation_data is not None:
@@ -255,7 +266,9 @@ class knn_cluster_purity(tf.keras.callbacks.Callback):
                     * np.max(groups[sample])
                 )
 
-            logs["knn_cluster_purity"] = purity_vector.mean()
+            tf.summary.scalar(
+                "knn_cluster_purity", data=purity_vector.mean(), step=epoch,
+            )
 
 
 class uncorrelated_features_constraint(Constraint):
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 9749e30273dab43572c0e18e37cdd52a2c6fc768..5a1c2d2f1ad587e994be52a1526aa129bbb4c7b6 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -119,6 +119,7 @@ def get_callbacks(
     onecycle = deepof.model_utils.one_cycle_scheduler(
         X_train.shape[0] // batch_size * 250,
         max_rate=0.005,
+        log_dir=os.path.join(outpath, "metrics")
     )
 
     callbacks = [run_ID, tensorboard_callback, knn, onecycle]