diff --git a/deepof/model_utils.py b/deepof/model_utils.py index 9c895bfbee0603d25eb9a85269ae76bc53c23694..2b7b0a396caa0432572d3fe0dabacdcc2f0660c4 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -144,6 +144,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): start_rate: float = None, last_iterations: int = None, last_rate: float = None, + log_dir: str = ".", ): super().__init__() self.iterations = iterations @@ -154,6 +155,7 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): self.last_rate = last_rate or self.start_rate / 1000 self.iteration = 0 self.history = {} + self.log_dir = log_dir def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float: return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1 @@ -186,7 +188,15 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback): self.iteration += 1 K.set_value(self.model.optimizer.lr, rate) - logs["learning_rate"] = self.last_rate + def on_epoch_end(self, epoch, logs=None): + """Logs the learning rate to tensorboard""" + + writer = tf.summary.create_file_writer(self.log_dir) + + with writer.as_default(): + tf.summary.scalar( + "learning_rate", data=self.model.optimizer.lr, step=epoch, + ) class knn_cluster_purity(tf.keras.callbacks.Callback): @@ -196,13 +206,14 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): """ - def __init__(self, k=5, samples=10000): + def __init__(self, validation_data=None, k=100, samples=10000): super().__init__() + self.validation_data = validation_data self.k = k self.samples = samples # noinspection PyMethodOverriding,PyTypeChecker - def on_epoch_end(self, batch: int, logs): + def on_epoch_end(self, epoch, logs=None): """ Passes samples through the encoder and computes cluster purity on the latent embedding """ if self.validation_data is not None: @@ -255,7 +266,9 @@ class knn_cluster_purity(tf.keras.callbacks.Callback): * np.max(groups[sample]) ) - logs["knn_cluster_purity"] = purity_vector.mean() + tf.summary.scalar( + "knn_cluster_purity", data=purity_vector.mean(), step=epoch, + ) class uncorrelated_features_constraint(Constraint): diff --git a/deepof/train_utils.py b/deepof/train_utils.py index 9749e30273dab43572c0e18e37cdd52a2c6fc768..5a1c2d2f1ad587e994be52a1526aa129bbb4c7b6 100644 --- a/deepof/train_utils.py +++ b/deepof/train_utils.py @@ -119,6 +119,7 @@ def get_callbacks( onecycle = deepof.model_utils.one_cycle_scheduler( X_train.shape[0] // batch_size * 250, max_rate=0.005, + log_dir=os.path.join(outpath, "metrics") ) callbacks = [run_ID, tensorboard_callback, knn, onecycle]