Commit 233e85d5 authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated cluster entropy regularization

parent bdd9d251
......@@ -528,10 +528,11 @@ class MMDiscrepancyLayer(Layer):
return z
class Cluster_overlap(Layer):
class ClusterOverlap(Layer):
"""
Identity layer that measures the overlap between the components of the latent Gaussian Mixture
using the average inter-cluster MMD as a metric
using the the entropy of the nearest neighbourhood. If self.loss_weight > 0, it adds a regularization
penalty to the loss function
"""
def __init__(
......@@ -546,8 +547,9 @@ class Cluster_overlap(Layer):
self.enc = encoding_dim
self.k = k
self.loss_weight = loss_weight
self.min_confidence = 0.25
self.samples = samples
super(Cluster_overlap, self).__init__(*args, **kwargs)
super(ClusterOverlap, self).__init__(*args, **kwargs)
def get_config(self): # pragma: no cover
"""Updates Constraint metadata"""
......@@ -556,40 +558,59 @@ class Cluster_overlap(Layer):
config.update({"enc": self.enc})
config.update({"k": self.k})
config.update({"loss_weight": self.loss_weight})
config.update({"min_confidence": self.min_confidence})
config.update({"samples": self.samples})
return config
@tf.function
def call(self, target, **kwargs):
def call(self, encodings, categorical, **kwargs):
"""Updates Layer's call method"""
dists = []
for k in range(self.k):
locs = (target[..., : self.enc, k],)
scales = tf.keras.activations.softplus(target[..., self.enc :, k])
hard_groups = categorical.argmax(axis=1)
max_groups = categorical.max(axis=1)
dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
)
knn = NearestNeighbors().fit(encodings)
# Iterate over samples and compute purity across neighbourhood
self.samples = tf.reduce_min([self.samples, encodings.shape[0]])
random_idxs = range(encoding.shape[0])
random_idxs = tf.random.categorical(
tf.math.log([[i / sum(random_idxs) for i in random_idxs]]), 2
)
purity_vector = tf.zeros(self.samples)
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
for i, sample in enumerate(random_idxs):
neighborhood = knn.kneighbors(
encoding[sample][np.newaxis, :], self.k, return_distance=False
).flatten()
# MMD-based overlap #
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
for c in combinations(range(len(dists)), 2)
],
dtype=tf.float32,
)
z = hard_groups[neighborhood]
# Compute Shannon entropy across samples
neigh_entropy = entropy(np.bincount(z))
# Add result to pre allocated array
purity_vector[i] = neigh_entropy
neighbourhood_entropy = purity_vector * max_groups[random_idxs]
self.add_metric(
len(set(hard_groups[max_groups >= self.min_confidence])),
aggregation="mean",
name="number_of_populated_clusters",
)
self.add_metric(
max_groups,
aggregation="mean",
name="average_confidence_in_selected_cluster",
)
self.add_metric(
-intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
)
if self.loss_weight:
self.add_loss(-intercomponent_mmd, inputs=[target])
self.add_loss(neighbourhood_entropy, inputs=[target, categorical])
return target
return encodings
......@@ -390,13 +390,6 @@ class GMVAE:
),
)(encoder)
if self.overlap_loss:
z_cat = deepof.model_utils.Cluster_overlap(
self.ENCODING,
self.number_of_components,
loss_weight=self.overlap_loss,
)(z_cat)
z_gauss_mean = Dense(
tfpl.IndependentNormal.params_size(
self.ENCODING * self.number_of_components
......@@ -482,6 +475,13 @@ class GMVAE:
# Dummy layer with no parameters, to retrieve the previous tensor
z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
if self.overlap_loss:
z = deepof.model_utils.ClusterOverlap(
self.ENCODING,
self.number_of_components,
loss_weight=self.overlap_loss,
)([z, z_cat])
# Define and instantiate generator
g = Input(shape=self.ENCODING)
generator = Sequential(Model_D1)(g)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment