Commit 84c2e957 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented mmd between Gaussian mixture components to track cluster...

Implemented mmd between Gaussian mixture components to track cluster overlapping. Adding a term to the loss function is an option
parent 9e636771
......@@ -168,7 +168,7 @@ class Gaussian_mixture_overlap(Layer):
"""
def __init__(
self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs
self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs
):
self.lat_dims = lat_dims
self.n_components = n_components
......@@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer):
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
tf.map_fn(compute_mmd, [dists[c[0]], dists[c[1]]], dtype=tf.float32)
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
for c in combinations(range(len(dists)), 2)
],
dtype=tf.float32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment