diff --git a/deepof/model_utils.py b/deepof/model_utils.py index b0a019cd73aba0a1ac6fac305d8f26f9b0c1ebf7..32c6352cc8536640bf5056ef6b6f89ecceff0c9f 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -266,11 +266,11 @@ class neighbor_cluster_purity(tf.keras.callbacks.Callback): for i, sample in enumerate(random_idxs): neighborhood = pdist[sample] < self.r + z = groups[neighborhood] + neigh_entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=self.axis) purity_vector[i] = ( - np.sum(hard_groups[neighborhood] == hard_groups[sample]) - / np.sum(neighborhood) - * np.max(groups[sample]) + neigh_entropy / np.sum(neighborhood) ) writer = tf.summary.create_file_writer(self.log_dir)