Skip to content
Snippets Groups Projects
Commit 84c2e957 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Implemented mmd between Gaussian mixture components to track cluster...

Implemented mmd between Gaussian mixture components to track cluster overlapping. Adding a term to the loss function is an option
parent 9e636771
Branches
Tags
No related merge requests found
......@@ -168,7 +168,7 @@ class Gaussian_mixture_overlap(Layer):
"""
def __init__(
self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs
self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs
):
self.lat_dims = lat_dims
self.n_components = n_components
......@@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer):
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
tf.map_fn(compute_mmd, [dists[c[0]], dists[c[1]]], dtype=tf.float32)
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
for c in combinations(range(len(dists)), 2)
],
dtype=tf.float32,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment