Commit 4c4303ad authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 1089c048
......@@ -15,6 +15,7 @@ import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from functools import partial
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
......@@ -46,7 +47,7 @@ def get_k_nearest_neighbors(tensor, k, index):
@tf.function
def get_neighbourhood_entropy(tensor, clusters, k, index):
def get_neighbourhood_entropy(index, tensor, clusters, k):
neighborhood = get_k_nearest_neighbors(tensor, k, index)
cluster_z = tf.gather(clusters, neighborhood)
neigh_entropy = compute_shannon_entropy(cluster_z)
......@@ -291,13 +292,14 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
range(encoding.shape[0]), self.samples, replace=False
)
# Add result to pre allocated array
purity_vector = np.zeros(self.samples)
for i, sample in enumerate(random_idxs):
purity_vector[i] = get_neighbourhood_entropy(
encodings, hard_groups, self.k, sample
)
get_local_neighbourhood_entropy = partial(
get_neighbourhood_entropy,
tensor=encodings,
clusters=hard_groups,
k=self.k,
dtype=tf.dtypes.float32,
)
purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
writer = tf.summary.create_file_writer(self.log_dir)
with writer.as_default():
......@@ -594,13 +596,16 @@ class ClusterOverlap(Layer):
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), self.samples
)
purity_vector = tf.map_fn(get_neighbourhood_entropy, random_idxs)
for i, sample in enumerate(random_idxs):
purity_vector[i] = get_neighbourhood_entropy(
encodings, hard_groups, self.k, sample
)
get_local_neighbourhood_entropy = partial(
get_neighbourhood_entropy,
tensor=encodings,
clusters=hard_groups,
k=self.k,
dtype=tf.dtypes.float32,
)
purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy = purity_vector * max_groups[random_idxs]
self.add_metric(
......
......@@ -426,9 +426,7 @@ class GMVAE:
tfd.Independent(
tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=1e-3
+ softplus(gauss[1][..., self.ENCODING :, k])
+ 1e-5,
scale=1e-3 + softplus(gauss[1][..., self.ENCODING :, k]),
),
reinterpreted_batch_ndims=1,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment