Commit 22f291e6 authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent a757402f
......@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function
def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tf.gather(tensor, index)
query = tf.gather(tensor, index, batch_dims=0)
distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance
......@@ -49,7 +49,7 @@ def get_k_nearest_neighbors(tensor, k, index):
@tf.function
def get_neighbourhood_entropy(index, tensor, clusters, k):
neighborhood = get_k_nearest_neighbors(tensor, k, index)
cluster_z = tf.gather(clusters, neighborhood)
cluster_z = tf.gather(clusters, neighborhood, batch_dims=0)
neigh_entropy = compute_shannon_entropy(cluster_z)
return neigh_entropy
......@@ -473,7 +473,6 @@ class ClusterOverlap(Layer):
encoding_dim: int,
k: int = 25,
loss_weight: float = 0.0,
samples: int = None,
*args,
**kwargs
):
......@@ -482,9 +481,6 @@ class ClusterOverlap(Layer):
self.k = k
self.loss_weight = loss_weight
self.min_confidence = 0.25
self.samples = samples
if self.samples is None:
self.samples = self.batch_size
super(ClusterOverlap, self).__init__(*args, **kwargs)
def get_config(self): # pragma: no cover
......@@ -507,11 +503,6 @@ class ClusterOverlap(Layer):
hard_groups = tf.math.argmax(categorical, axis=1)
max_groups = tf.reduce_max(categorical, axis=1)
# Iterate over samples and compute purity across neighbourhood
self.samples = np.min([self.samples, self.batch_size])
random_idxs = range(self.batch_size)
random_idxs = np.random.choice(random_idxs, self.samples)
get_local_neighbourhood_entropy = partial(
get_neighbourhood_entropy,
tensor=encodings,
......@@ -521,14 +512,12 @@ class ClusterOverlap(Layer):
purity_vector = tf.map_fn(
get_local_neighbourhood_entropy,
tf.constant(random_idxs),
tf.constant(list(range(self.batch_size))),
dtype=tf.dtypes.float32,
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy = purity_vector * tf.gather(
max_groups, tf.constant(random_idxs)
)
neighbourhood_entropy = purity_vector * max_groups
number_of_clusters = tf.cast(
tf.shape(
......@@ -537,6 +526,7 @@ class ClusterOverlap(Layer):
tf.gather(
tf.cast(hard_groups, tf.dtypes.float32),
tf.where(max_groups >= self.min_confidence),
batch_dims=0,
),
[-1],
),
......
......@@ -475,9 +475,9 @@ class GMVAE:
if self.number_of_components > 1:
z = deepof.model_utils.ClusterOverlap(
self.batch_size,
self.ENCODING,
self.number_of_components,
batch_size=self.batch_size,
encoding_dim=self.ENCODING,
k=self.number_of_components,
loss_weight=self.overlap_loss,
)([z, z_cat])
......
......@@ -86,7 +86,7 @@
"outputs": [],
"source": [
"path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof-data\", \"deepof_single_topview\")\n",
"trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof_trained_weights_280521\", \"var_annealing\")\n",
"trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof_trained_weights_280521\", \"var_overlap_loss\")\n",
"exclude_bodyparts = tuple([\"\"])\n",
"window_size = 24"
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment