diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index f6b8acba81f9888cf3120b24f2202774e317d736..064d4cb2fbc9ee7c371ce71fc13da60f27b91cc4 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -560,9 +560,9 @@ class ClusterOverlap(Layer):
         self,
         batch_size: int,
         encoding_dim: int,
-        k: int = 100,
+        k: int = 25,
         loss_weight: float = 0.0,
-        samples: int = 50,
+        samples: int = None,  # None defaults to the batch size
         *args,
         **kwargs
     ):
@@ -572,6 +572,8 @@ class ClusterOverlap(Layer):
         self.loss_weight = loss_weight
         self.min_confidence = 0.25
         self.samples = samples
+        if self.samples is None:
+            self.samples = self.batch_size  # default: draw batch_size samples
         super(ClusterOverlap, self).__init__(*args, **kwargs)
 
     def get_config(self):  # pragma: no cover
@@ -595,8 +597,8 @@ class ClusterOverlap(Layer):
         max_groups = tf.reduce_max(categorical, axis=1)
 
         # Iterate over samples and compute purity across neighbourhood
-        self.samples = np.min([self.samples, self.batch_size])  # convert to numpy
-        random_idxs = range(self.batch_size)  # convert to batch size
+        self.samples = np.min([self.samples, self.batch_size])  # cap at batch size
+        random_idxs = range(self.batch_size)
         random_idxs = np.random.choice(random_idxs, self.samples)
 
         get_local_neighbourhood_entropy = partial(
@@ -628,9 +630,7 @@ class ClusterOverlap(Layer):
             neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
         )
 
-        # if self.loss_weight:
-        #     self.add_loss(
-        #         self.loss_weight * neighbourhood_entropy, inputs=inputs
-        #     )
+        if self.loss_weight:
+            self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
 
         return encodings
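
For context, ClusterOverlap's neighbourhood-entropy statistic measures how mixed
the cluster assignments are within each sample's k-nearest-neighbour ball in
latent space, so re-enabling add_loss above lets a non-zero loss_weight penalise
overlapping clusters directly. A minimal NumPy sketch of the idea (the function
name and exact entropy computation are illustrative, not the layer's code; k=25
mirrors the new default):

    import numpy as np

    def knn_label_entropy(encodings: np.ndarray, labels: np.ndarray, k: int = 25) -> float:
        """Mean Shannon entropy of cluster labels over each point's k nearest neighbours."""
        dists = np.linalg.norm(encodings[:, None, :] - encodings[None, :, :], axis=-1)
        knn = np.argsort(dists, axis=1)[:, 1 : k + 1]  # drop the self-match at distance 0
        entropies = []
        for neighbours in knn:
            counts = np.bincount(labels[neighbours])
            p = counts[counts > 0] / counts.sum()
            entropies.append(-(p * np.log(p)).sum())
        return float(np.mean(entropies))

With samples=None the layer now draws batch_size neighbourhood samples per step
instead of a 50-point random subsample.
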
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 9ac9f1dc1b38b8b4a5327f6c16c2cb19b32fb459..85042360817fc0a7ec72e849fba36a1d07cb841e 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -324,7 +324,7 @@ def autoencoder_fitting(
     X_val_dataset = (
         tf.data.Dataset.from_tensor_slices(X_val)
         .with_options(options)
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
     )
 
     # Defines what to log on tensorboard (useful for trying out different models)
@@ -441,13 +441,13 @@ def autoencoder_fitting(
     # Convert data to tf.data.Dataset objects
     train_dataset = (
         tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
         .shuffle(buffer_size=X_train.shape[0])
         .with_options(options)
     )
     val_dataset = (
         tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
         .with_options(options)
     )
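
The drop_remainder=True changes matter because ClusterOverlap is built with a
fixed batch_size: a short final batch would reach it with a mismatched static
shape. A quick illustration of the batching behaviour, using nothing beyond
stock tf.data:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)

    # Default batching keeps the short final batch: sizes [4, 4, 2].
    print([int(b.shape[0]) for b in ds.batch(4)])

    # drop_remainder=True discards it, so every batch holds exactly 4 elements
    # and batch-size-dependent layers see a static batch dimension.
    print([int(b.shape[0]) for b in ds.batch(4, drop_remainder=True)])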