diff --git a/deepof/model_utils.py b/deepof/model_utils.py index 37356da85d05ccbe0496958e7d8e4e4384d3986e..dbe1d3b43f0f8c20aec230c5b7481b7711e2aa9c 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor): @tf.function def get_k_nearest_neighbors(tensor, k, index): """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index""" - query = tf.slice(tensor, index, 1) + query = tensor[index] distances = tf.norm(tensor - query, axis=1) max_distance = tf.sort(distances)[k] neighbourhood_mask = distances < max_distance