From 6146894c57423229f02baf3a7f981a64e53bf836 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Wed, 19 May 2021 01:27:59 +0200
Subject: [PATCH] Replaced for loop with vectorised mapping on ClusterOverlap
 regularization layer

---
 deepof/model_utils.py | 16 ++++++++--------
 deepof/train_utils.py |  6 +++---
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index f6b8acba..064d4cb2 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -560,9 +560,9 @@ class ClusterOverlap(Layer):
         self,
         batch_size: int,
         encoding_dim: int,
-        k: int = 100,
+        k: int = 25,
         loss_weight: float = 0.0,
-        samples: int = 50,
+        samples: int = None,
         *args,
         **kwargs
     ):
@@ -572,6 +572,8 @@ class ClusterOverlap(Layer):
         self.loss_weight = loss_weight
         self.min_confidence = 0.25
         self.samples = samples
+        if self.samples is None:
+            self.samples = self.batch_size
         super(ClusterOverlap, self).__init__(*args, **kwargs)
 
     def get_config(self):  # pragma: no cover
@@ -595,8 +597,8 @@ class ClusterOverlap(Layer):
         max_groups = tf.reduce_max(categorical, axis=1)
 
         # Iterate over samples and compute purity across neighbourhood
-        self.samples = np.min([self.samples, self.batch_size])  # convert to numpy
-        random_idxs = range(self.batch_size)  # convert to batch size
+        self.samples = np.min([self.samples, self.batch_size])
+        random_idxs = range(self.batch_size)
         random_idxs = np.random.choice(random_idxs, self.samples)
 
         get_local_neighbourhood_entropy = partial(
@@ -628,9 +630,7 @@ class ClusterOverlap(Layer):
             neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
         )
 
-        # if self.loss_weight:
-        #     self.add_loss(
-        #         self.loss_weight * neighbourhood_entropy, inputs=inputs
-        #     )
+        if self.loss_weight:
+            self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
 
         return encodings
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index 9ac9f1dc..85042360 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -324,7 +324,7 @@ def autoencoder_fitting(
     X_val_dataset = (
         tf.data.Dataset.from_tensor_slices(X_val)
         .with_options(options)
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
     )
 
     # Defines what to log on tensorboard (useful for trying out different models)
@@ -441,13 +441,13 @@ def autoencoder_fitting(
     # Convert data to tf.data.Dataset objects
     train_dataset = (
         tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
         .shuffle(buffer_size=X_train.shape[0])
         .with_options(options)
     )
     val_dataset = (
         tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
         .with_options(options)
     )
 
-- 
GitLab