From 84c2e957c67119e715a1937b49cceebabc6f45e8 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 7 Jul 2020 18:36:53 +0200
Subject: [PATCH] Implemented mmd between Gaussian mixture components to track
 cluster overlapping. Adding a term to the loss function is an option

---
 source/model_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/source/model_utils.py b/source/model_utils.py
index cead8076..7918608a 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -168,7 +168,7 @@ class Gaussian_mixture_overlap(Layer):
     """
 
     def __init__(
-        self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs
+        self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs
     ):
         self.lat_dims = lat_dims
         self.n_components = n_components
@@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer):
         intercomponent_mmd = K.mean(
             tf.convert_to_tensor(
                 [
-                    tf.map_fn(compute_mmd, [dists[c[0]], dists[c[1]]], dtype=tf.float32)
+                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
                     for c in combinations(range(len(dists)), 2)
                 ],
                 dtype=tf.float32,
-- 
GitLab