diff --git a/deepof/model_utils.py b/deepof/model_utils.py index b5763c28bdf4c90876f9ddfbf394435d477ef067..37356da85d05ccbe0496958e7d8e4e4384d3986e 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor): @tf.function def get_k_nearest_neighbors(tensor, k, index): """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index""" - query = tensor[index] + query = tf.slice(tensor, index, 1) distances = tf.norm(tensor - query, axis=1) max_distance = tf.sort(distances)[k] neighbourhood_mask = distances < max_distance @@ -583,7 +583,7 @@ class ClusterOverlap(Layer): config.update({"samples": self.samples}) return config - #@tf.function + @tf.function def call(self, inputs, **kwargs): """Updates Layer's call method""" @@ -594,8 +594,12 @@ class ClusterOverlap(Layer): # Iterate over samples and compute purity across neighbourhood self.samples = tf.reduce_min([self.samples, tf.shape(encodings)[0]]) - random_idxs = range(encodings.shape[0]) - random_idxs = np.random.choice(random_idxs, self.samples) + random_idxs = tf.range(tf.shape(encodings)[0]) + random_idxs = tf.random.categorical( + tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), + self.samples, + dtype=tf.dtypes.int32, + ) @tf.function def get_local_neighbourhood_entropy(index): @@ -604,7 +608,7 @@ class ClusterOverlap(Layer): ) purity_vector = tf.map_fn( - get_local_neighbourhood_entropy, tf.constant(random_idxs), dtype=tf.dtypes.float32 + get_local_neighbourhood_entropy, random_idxs, dtype=tf.dtypes.float32 ) ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###