Skip to content
Snippets Groups Projects
Commit 84c2e957 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Implemented mmd between Gaussian mixture components to track cluster...

Implemented mmd between Gaussian mixture components to track cluster overlapping. Adding a term to the loss function is an option
parent 9e636771
No related branches found
No related tags found
No related merge requests found
...@@ -168,7 +168,7 @@ class Gaussian_mixture_overlap(Layer): ...@@ -168,7 +168,7 @@ class Gaussian_mixture_overlap(Layer):
""" """
def __init__( def __init__(
self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs
): ):
self.lat_dims = lat_dims self.lat_dims = lat_dims
self.n_components = n_components self.n_components = n_components
...@@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer): ...@@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer):
intercomponent_mmd = K.mean( intercomponent_mmd = K.mean(
tf.convert_to_tensor( tf.convert_to_tensor(
[ [
tf.map_fn(compute_mmd, [dists[c[0]], dists[c[1]]], dtype=tf.float32) tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
for c in combinations(range(len(dists)), 2) for c in combinations(range(len(dists)), 2)
], ],
dtype=tf.float32, dtype=tf.float32,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment