diff --git a/model_training.py b/model_training.py
index c222b819e9c51be85795586f5dfed1a2149e6da0..4d3c0078b11d42beebd496e5bc3a3f900fa97d66 100644
--- a/model_training.py
+++ b/model_training.py
@@ -90,6 +90,20 @@ parser.add_argument(
     default=16,
     type=int,
 )
+parser.add_argument(
+    "--overlap-loss",
+    "-ol",
+    help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
+    default=False,
+    type=str2bool
+)
+parser.add_argument(
+    "--batch-size",
+    "-bs",
+    help="set training batch size. Defaults to 512",
+    type=int,
+    default=512
+)
 
 args = parser.parse_args()
 train_path = os.path.abspath(args.train_path)
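Note: the new `--overlap-loss` argument relies on a `str2bool` converter that is not part of this diff. Assuming the conventional argparse helper, it would look roughly like this (a sketch, not necessarily the repository's actual definition):

```python
import argparse


def str2bool(v):
    # Accept real bools plus common textual values ("true", "no", "1", ...).
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
```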
@@ -103,6 +117,8 @@ kl_wu = args.kl_warmup
 mmd_wu = args.mmd_warmup
 hparams = args.hyperparameters
 encoding = args.encoding_size
+batch_size = args.batch_size
+overlap_loss = args.overlap_loss
 
 if not train_path:
     raise ValueError("Set a valid data path for the training to run")
@@ -372,7 +388,7 @@ if not variational:
         x=input_dict_train[input_type],
         y=input_dict_train[input_type],
         epochs=250,
-        batch_size=512,
+        batch_size=batch_size,
         verbose=1,
         validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
         callbacks=[
@@ -399,6 +415,7 @@ else:
         kl_warmup_epochs=kl_wu,
         mmd_warmup_epochs=mmd_wu,
         predictor=predictor,
+        overlap_loss=overlap_loss,
         **hparams
     ).build()
     gmvaep.build(input_dict_train[input_type].shape)
@@ -423,7 +440,7 @@ else:
             x=input_dict_train[input_type],
             y=input_dict_train[input_type],
             epochs=250,
-            batch_size=512,
+            batch_size=batch_size,
             verbose=1,
             validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
             callbacks=callbacks_,
@@ -433,7 +450,7 @@ else:
             x=input_dict_train[input_type][:-1],
             y=[input_dict_train[input_type][:-1], input_dict_train[input_type][1:]],
             epochs=250,
-            batch_size=512,
+            batch_size=batch_size,
             verbose=1,
             validation_data=(
                 input_dict_val[input_type][:-1],
diff --git a/source/model_utils.py b/source/model_utils.py
index b27cfd3fc4882ece1218630f5edc1dc010aef038..e705390be1040a97726547616885c3ec69bda4bc 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -2,6 +2,7 @@
 
 from itertools import combinations
 from keras import backend as K
+from scipy.stats import wasserstein_distance
 from sklearn.metrics import silhouette_score
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -133,24 +134,26 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
 
 class MMDiscrepancyLayer(Layer):
     """
-    Identity transform layer that adds MM discrepancy
+    Identity transform layer that adds Maximum Mean Discrepancy (MMD)
     to the final model loss.
     """
 
-    def __init__(self, prior, beta=1.0, *args, **kwargs):
+    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
+        self.batch_size = batch_size
         self.beta = beta
         self.prior = prior
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
 
     def get_config(self):
         config = super().get_config().copy()
+        config.update({"batch_size": self.batch_size})
         config.update({"beta": self.beta})
         config.update({"prior": self.prior})
         return config
 
     def call(self, z, **kwargs):
-        true_samples = self.prior.sample(1)
+        true_samples = self.prior.sample(self.batch_size)
         mmd_batch = self.beta * compute_mmd([true_samples, z])
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
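The `MMDiscrepancyLayer` change above draws `batch_size` samples from the prior instead of a single one, so `compute_mmd` compares two samples of comparable size; with `prior.sample(1)` the prior-side kernel terms collapse to a single point and the estimate is largely meaningless. `compute_mmd` itself is not shown in this diff; a standard RBF-kernel estimator consistent with how the layer calls it would look roughly like this (a sketch under that assumption):

```python
import tensorflow as tf
from tensorflow.keras import backend as K


def _rbf_kernel(x, y):
    # Pairwise RBF kernel matrix between two sample sets of shape (n, dim).
    x_size, y_size = K.shape(x)[0], K.shape(y)[0]
    dim = K.shape(x)[1]
    tiled_x = K.tile(K.reshape(x, [x_size, 1, dim]), [1, y_size, 1])
    tiled_y = K.tile(K.reshape(y, [1, y_size, dim]), [x_size, 1, 1])
    return K.exp(
        -K.mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32)
    )


def compute_mmd(tensors):
    # MMD^2(x, y) = E[k(x, x)] + E[k(y, y)] - 2 E[k(x, y)]
    x, y = tensors
    return (
        K.mean(_rbf_kernel(x, x))
        + K.mean(_rbf_kernel(y, y))
        - 2 * K.mean(_rbf_kernel(x, y))
    )
```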
@@ -166,18 +169,10 @@ class Gaussian_mixture_overlap(Layer):
     """
 
     def __init__(
-        self,
-        lat_dims,
-        n_components,
-        metric="mmd",
-        loss=False,
-        samples=100,
-        *args,
-        **kwargs
+        self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs
     ):
         self.lat_dims = lat_dims
         self.n_components = n_components
-        self.metric = metric
         self.loss = loss
         self.samples = samples
         super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
@@ -186,7 +181,6 @@ class Gaussian_mixture_overlap(Layer):
         config = super().get_config().copy()
         config.update({"lat_dims": self.lat_dims})
         config.update({"n_components": self.n_components})
-        config.update({"metric": self.metric})
         config.update({"loss": self.loss})
         config.update({"samples": self.samples})
         return config
@@ -204,27 +198,23 @@ class Gaussian_mixture_overlap(Layer):
 
         dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
 
-        if self.metric == "mmd":
-
-            intercomponent_mmd = K.mean(
-                tf.convert_to_tensor(
-                    [
-                        tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
-                        for c in combinations(range(len(dists)), 2)
-                    ],
-                    dtype=tf.float32,
-                )
-            )
-
-            self.add_metric(
-                intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
+        ### MMD-based overlap ###
+        intercomponent_mmd = K.mean(
+            tf.convert_to_tensor(
+                [
+                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
+                    for c in combinations(range(len(dists)), 2)
+                ],
+                dtype=tf.float32,
             )
+        )
 
-            if self.loss:
-                self.add_loss(-intercomponent_mmd, inputs=[target])
+        self.add_metric(
+            intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
+        )
 
-        elif self.metric == "wasserstein":
-            pass
+        if self.loss:
+            self.add_loss(-intercomponent_mmd, inputs=[target])
 
         return target
 
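With the dead `wasserstein` branch gone (note that the `scipy.stats.wasserstein_distance` import added at the top of this file is currently unused), the layer always reports the mean MMD over every pair of mixture components and, when `loss=True`, adds its negative to the model loss: minimizing `-intercomponent_mmd` drives the components apart, i.e. penalizes overlap. A toy check of the sign convention, reusing the `compute_mmd` sketch above (illustrative only):

```python
import numpy as np
import tensorflow as tf

# Two well-separated 2D component samples give a large MMD, hence a strongly
# negative loss contribution; overlapping samples give a value near zero.
a = tf.constant(np.random.normal(0.0, 1.0, size=(100, 2)), tf.float32)
b = tf.constant(np.random.normal(5.0, 1.0, size=(100, 2)), tf.float32)
print(-float(compute_mmd((a, b))))  # noticeably negative
print(-float(compute_mmd((a, a))))  # exactly zero for identical samples
```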
@@ -250,7 +240,7 @@ class Latent_space_control(Layer):
             tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
         )
 
-        # Adds Silhouette score controling overlap between clusters
+        # Adds Silhouette score controlling overlap between clusters
         hard_labels = tf.math.argmax(z_cat, axis=1)
         silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
         self.add_metric(silhouette, aggregation="mean", name="silhouette")
diff --git a/source/models.py b/source/models.py
index 90b69b349597b738bf67746bc50a4c276b8e00ac..afdf3843cb78f5b6a713194658de67dee6b4dcab 100644
--- a/source/models.py
+++ b/source/models.py
@@ -5,7 +5,7 @@ from tensorflow.keras import Input, Model, Sequential
 from tensorflow.keras.activations import softplus
 from tensorflow.keras.callbacks import LambdaCallback
 from tensorflow.keras.constraints import UnitNorm
-from tensorflow.keras.initializers import he_uniform, Orthogonal, RandomNormal
+from tensorflow.keras.initializers import he_uniform, Orthogonal
 from tensorflow.keras.layers import BatchNormalization, Bidirectional
 from tensorflow.keras.layers import Dense, Dropout, LSTM
 from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
@@ -155,6 +155,7 @@ class SEQ_2_SEQ_GMVAE:
     def __init__(
         self,
         input_shape,
+        batch_size=512,
         units_conv=256,
         units_lstm=256,
         units_dense2=64,
@@ -167,10 +168,10 @@ class SEQ_2_SEQ_GMVAE:
         prior="standard_normal",
         number_of_components=1,
         predictor=True,
-        overlap_metric="mmd",
         overlap_loss=False,
     ):
         self.input_shape = input_shape
+        self.batch_size = batch_size
         self.CONV_filters = units_conv
         self.LSTM_units_1 = units_lstm
         self.LSTM_units_2 = int(units_lstm / 2)
@@ -185,7 +186,6 @@ class SEQ_2_SEQ_GMVAE:
         self.mmd_warmup = mmd_warmup_epochs
         self.number_of_components = number_of_components
         self.predictor = predictor
-        self.overlap_metric = overlap_metric
         self.overlap_loss = overlap_loss
 
         if self.prior == "standard_normal":
@@ -303,10 +303,7 @@ class SEQ_2_SEQ_GMVAE:
 
         z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
         z_gauss = Gaussian_mixture_overlap(
-            self.ENCODING,
-            self.number_of_components,
-            metric=self.overlap_metric,
-            loss=self.overlap_loss,
+            self.ENCODING, self.number_of_components, loss=self.overlap_loss,
         )(z_gauss)
 
         z = tfpl.DistributionLambda(
@@ -353,10 +350,12 @@ class SEQ_2_SEQ_GMVAE:
                     )
                 )
 
-            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
+            z = MMDiscrepancyLayer(
+                batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
+            )(z)
 
         # Identity layer controlling clustering and latent space statistics
-        z = Latent_space_control()(z, z_gauss, z_cat)
+        z = Latent_space_control(loss=self.overlap_loss)(z, z_gauss, z_cat)
 
         # Define and instantiate generator
         generator = Model_D1(z)
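For reference, the new constructor arguments thread through as in `model_training.py` above; a hypothetical instantiation (shape and component values are illustrative, every other parameter is left at its default):

```python
# batch_size should match the --batch-size used in fit(), since
# MMDiscrepancyLayer now draws exactly that many prior samples per step.
gmvaep_factory = SEQ_2_SEQ_GMVAE(
    input_shape=(512, 24, 28),  # hypothetical (batch, window, features)
    batch_size=512,
    number_of_components=5,
    overlap_loss=True,  # wires -intercomponent_mmd into the total loss
)
```

An alternative design would be to sample `K.shape(z)[0]` points inside the layer so it never depends on a fixed batch size; passing `batch_size` explicitly keeps the prior draw's shape static at graph-construction time.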