Commit 5b68125c authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 0958d797
......@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function
def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tf.slice(tensor, index, 1)
query = tensor[index]
distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment