diff --git a/source/model_utils.py b/source/model_utils.py
index cfbfa0ccda1396f906dcb47fe74d8ea3d237f4ae..53e7e86a29dd5abb12e335d9672dda11345f93bb 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -1,5 +1,6 @@
 # @author lucasmiranda42
 
+from itertools import combinations
 from keras import backend as K
 from sklearn.metrics import silhouette_score
 from tensorflow.keras.constraints import Constraint
@@ -22,7 +23,11 @@ def compute_kernel(x, y):
     )
 
 
-def compute_mmd(x, y):
+def compute_mmd(tensors):
+
+    x = tensors[0]
+    y = tensors[1]
+
     x_kernel = compute_kernel(x, x)
     y_kernel = compute_kernel(y, y)
     xy_kernel = compute_kernel(x, y)
@@ -127,7 +132,8 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
 
 
 class MMDiscrepancyLayer(Layer):
-    """ Identity transform layer that adds MM discrepancy
+    """
+    Identity transform layer that adds MM discrepancy
     to the final model loss.
     """
 
@@ -153,10 +159,78 @@ class MMDiscrepancyLayer(Layer):
         return z
 
 
+class Gaussian_mixture_overlap(Layer):
+    """
+    Identity layer that measures the overlap between the components of the latent Gaussian Mixture
+    using a specified metric (MMD, Wasserstein, Fischer-Rao)
+    """
+
+    def __init__(
+        self,
+        lat_dims,
+        n_components,
+        metric="mmd",
+        loss=False,
+        samples=100,
+        *args,
+        **kwargs
+    ):
+        self.lat_dims = lat_dims
+        self.n_components = n_components
+        self.metric = metric
+        self.loss = loss
+        self.samples = samples
+        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({"lat_dims": self.lat_dims})
+        config.update({"n_components": self.n_components})
+        config.update({"metric": self.metric})
+        config.update({"loss": self.loss})
+        config.update({"samples": self.samples})
+        return config
+
+    def call(self, target, loss=False):
+
+        dists = []
+        for k in range(self.n_components):
+            locs = (target[..., : self.lat_dims, k],)
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+
+            dists.append(tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1]))
+
+        print(dists)
+        dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
+        print(dists)
+
+        if self.metric == "mmd":
+
+            intercomponent_mmd = K.mean(
+                tf.convert_to_tensor(
+                    [
+                        tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
+                        for c in combinations(range(len(dists)), 2)
+                    ],
+                    dtype=tf.float32,
+                )
+            )
+            print(intercomponent_mmd)
+            self.add_metric(
+                intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
+            )
+
+        elif self.metric == "wasserstein":
+            pass
+
+        return target
+
+
 class Latent_space_control(Layer):
-    """ Identity layer that adds latent space and clustering stats
-     to the metrics compiled by the model
-     """
+    """
+    Identity layer that adds latent space and clustering stats
+    to the metrics compiled by the model
+    """
 
     def __init__(self, *args, **kwargs):
         super(Latent_space_control, self).__init__(*args, **kwargs)
diff --git a/source/models.py b/source/models.py
index 587c5a953c490f9460fe60d9365465afea8667d9..21fd0e71edbe6811d0da64e85df19068c6a25961 100644
--- a/source/models.py
+++ b/source/models.py
@@ -167,6 +167,7 @@ class SEQ_2_SEQ_GMVAE:
         prior="standard_normal",
         number_of_components=1,
         predictor=True,
+        overlap_metric="mmd",
     ):
         self.input_shape = input_shape
         self.CONV_filters = units_conv
@@ -183,6 +184,7 @@ class SEQ_2_SEQ_GMVAE:
         self.mmd_warmup = mmd_warmup_epochs
         self.number_of_components = number_of_components
         self.predictor = predictor
+        self.overlap_metric = overlap_metric
 
         if self.prior == "standard_normal":
             self.prior = tfd.mixture.Mixture(
@@ -298,6 +300,10 @@ class SEQ_2_SEQ_GMVAE:
         )(encoder)
 
         z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
+        z_gauss = Gaussian_mixture_overlap(
+            self.ENCODING, self.number_of_components, metric=self.overlap_metric
+        )(z_gauss)
+
         z = tfpl.DistributionLambda(
             lambda gauss: tfd.mixture.Mixture(
                 cat=tfd.categorical.Categorical(probs=gauss[0],),
@@ -438,10 +444,10 @@ class SEQ_2_SEQ_GMVAE:
 
 
 # TODO:
-#       - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning
+#       - latent space metrics to control overregulatization (turned off dimensions). Useful for warmup tuning (done!)
 #       - Clustering metrics for model selection and aid training (eg early stopping)
-#           - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics
-#       - design clustering-conscious hyperparameter tuing pipeline
+#           - Silhouette / mMMD / Fischer-Mao / Wasserstein
+#       - design clustering-conscious hyperparameter tuning pipeline
 
 # TODO (in the non-immediate future):
 #       - Try Bayesian nets!