From 84c2e957c67119e715a1937b49cceebabc6f45e8 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Tue, 7 Jul 2020 18:36:53 +0200 Subject: [PATCH] Implemented mmd between Gaussian mixture components to track cluster overlapping. Adding a term to the loss function is an option --- source/model_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/model_utils.py b/source/model_utils.py index cead8076..7918608a 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -168,7 +168,7 @@ class Gaussian_mixture_overlap(Layer): """ def __init__( - self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs + self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs ): self.lat_dims = lat_dims self.n_components = n_components @@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer): intercomponent_mmd = K.mean( tf.convert_to_tensor( [ - tf.map_fn(compute_mmd, [dists[c[0]], dists[c[1]]], dtype=tf.float32) + tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]]) for c in combinations(range(len(dists)), 2) ], dtype=tf.float32, -- GitLab