Commit 22f291e6 authored by lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent a757402f
...@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor): ...@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function @tf.function
def get_k_nearest_neighbors(tensor, k, index): def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index""" """Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query = tf.gather(tensor, index) query = tf.gather(tensor, index, batch_dims=0)
distances = tf.norm(tensor - query, axis=1) distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k] max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance neighbourhood_mask = distances < max_distance
...@@ -49,7 +49,7 @@ def get_k_nearest_neighbors(tensor, k, index): ...@@ -49,7 +49,7 @@ def get_k_nearest_neighbors(tensor, k, index):
@tf.function @tf.function
def get_neighbourhood_entropy(index, tensor, clusters, k): def get_neighbourhood_entropy(index, tensor, clusters, k):
neighborhood = get_k_nearest_neighbors(tensor, k, index) neighborhood = get_k_nearest_neighbors(tensor, k, index)
cluster_z = tf.gather(clusters, neighborhood) cluster_z = tf.gather(clusters, neighborhood, batch_dims=0)
neigh_entropy = compute_shannon_entropy(cluster_z) neigh_entropy = compute_shannon_entropy(cluster_z)
return neigh_entropy return neigh_entropy
...@@ -473,7 +473,6 @@ class ClusterOverlap(Layer): ...@@ -473,7 +473,6 @@ class ClusterOverlap(Layer):
encoding_dim: int, encoding_dim: int,
k: int = 25, k: int = 25,
loss_weight: float = 0.0, loss_weight: float = 0.0,
samples: int = None,
*args, *args,
**kwargs **kwargs
): ):
...@@ -482,9 +481,6 @@ class ClusterOverlap(Layer): ...@@ -482,9 +481,6 @@ class ClusterOverlap(Layer):
self.k = k self.k = k
self.loss_weight = loss_weight self.loss_weight = loss_weight
self.min_confidence = 0.25 self.min_confidence = 0.25
self.samples = samples
if self.samples is None:
self.samples = self.batch_size
super(ClusterOverlap, self).__init__(*args, **kwargs) super(ClusterOverlap, self).__init__(*args, **kwargs)
def get_config(self): # pragma: no cover def get_config(self): # pragma: no cover
...@@ -507,11 +503,6 @@ class ClusterOverlap(Layer): ...@@ -507,11 +503,6 @@ class ClusterOverlap(Layer):
hard_groups = tf.math.argmax(categorical, axis=1) hard_groups = tf.math.argmax(categorical, axis=1)
max_groups = tf.reduce_max(categorical, axis=1) max_groups = tf.reduce_max(categorical, axis=1)
# Iterate over samples and compute purity across neighbourhood
self.samples = np.min([self.samples, self.batch_size])
random_idxs = range(self.batch_size)
random_idxs = np.random.choice(random_idxs, self.samples)
get_local_neighbourhood_entropy = partial( get_local_neighbourhood_entropy = partial(
get_neighbourhood_entropy, get_neighbourhood_entropy,
tensor=encodings, tensor=encodings,
...@@ -521,14 +512,12 @@ class ClusterOverlap(Layer): ...@@ -521,14 +512,12 @@ class ClusterOverlap(Layer):
purity_vector = tf.map_fn( purity_vector = tf.map_fn(
get_local_neighbourhood_entropy, get_local_neighbourhood_entropy,
tf.constant(random_idxs), tf.constant(list(range(self.batch_size))),
dtype=tf.dtypes.float32, dtype=tf.dtypes.float32,
) )
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ### ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy = purity_vector * tf.gather( neighbourhood_entropy = purity_vector * max_groups
max_groups, tf.constant(random_idxs)
)
number_of_clusters = tf.cast( number_of_clusters = tf.cast(
tf.shape( tf.shape(
...@@ -537,6 +526,7 @@ class ClusterOverlap(Layer): ...@@ -537,6 +526,7 @@ class ClusterOverlap(Layer):
tf.gather( tf.gather(
tf.cast(hard_groups, tf.dtypes.float32), tf.cast(hard_groups, tf.dtypes.float32),
tf.where(max_groups >= self.min_confidence), tf.where(max_groups >= self.min_confidence),
batch_dims=0,
), ),
[-1], [-1],
), ),
......
...@@ -475,9 +475,9 @@ class GMVAE: ...@@ -475,9 +475,9 @@ class GMVAE:
if self.number_of_components > 1: if self.number_of_components > 1:
z = deepof.model_utils.ClusterOverlap( z = deepof.model_utils.ClusterOverlap(
self.batch_size, batch_size=self.batch_size,
self.ENCODING, encoding_dim=self.ENCODING,
self.number_of_components, k=self.number_of_components,
loss_weight=self.overlap_loss, loss_weight=self.overlap_loss,
)([z, z_cat]) )([z, z_cat])
......
...@@ -86,7 +86,7 @@ ...@@ -86,7 +86,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof-data\", \"deepof_single_topview\")\n", "path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof-data\", \"deepof_single_topview\")\n",
"trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof_trained_weights_280521\", \"var_annealing\")\n", "trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof_trained_weights_280521\", \"var_overlap_loss\")\n",
"exclude_bodyparts = tuple([\"\"])\n", "exclude_bodyparts = tuple([\"\"])\n",
"window_size = 24" "window_size = 24"
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment