Commit 6146894c authored by lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 9a805f04
......@@ -560,9 +560,9 @@ class ClusterOverlap(Layer):
self,
batch_size: int,
encoding_dim: int,
k: int = 100,
k: int = 25,
loss_weight: float = 0.0,
samples: int = 50,
samples: int = None,
*args,
**kwargs
):
......@@ -572,6 +572,8 @@ class ClusterOverlap(Layer):
self.loss_weight = loss_weight
self.min_confidence = 0.25
self.samples = samples
if self.samples is None:
self.samples = self.batch_size
super(ClusterOverlap, self).__init__(*args, **kwargs)
def get_config(self): # pragma: no cover
......@@ -595,8 +597,8 @@ class ClusterOverlap(Layer):
max_groups = tf.reduce_max(categorical, axis=1)
# Iterate over samples and compute purity across neighbourhood
self.samples = np.min([self.samples, self.batch_size]) # convert to numpy
random_idxs = range(self.batch_size) # convert to batch size
self.samples = np.min([self.samples, self.batch_size])
random_idxs = range(self.batch_size)
random_idxs = np.random.choice(random_idxs, self.samples)
get_local_neighbourhood_entropy = partial(
......@@ -628,9 +630,7 @@ class ClusterOverlap(Layer):
neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
)
# if self.loss_weight:
# self.add_loss(
# self.loss_weight * neighbourhood_entropy, inputs=inputs
# )
if self.loss_weight:
self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
return encodings
......@@ -324,7 +324,7 @@ def autoencoder_fitting(
X_val_dataset = (
tf.data.Dataset.from_tensor_slices(X_val)
.with_options(options)
.batch(batch_size * strategy.num_replicas_in_sync)
.batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
)
# Defines what to log on tensorboard (useful for trying out different models)
......@@ -441,13 +441,13 @@ def autoencoder_fitting(
# Convert data to tf.data.Dataset objects
train_dataset = (
tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
.batch(batch_size * strategy.num_replicas_in_sync)
.batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
.shuffle(buffer_size=X_train.shape[0])
.with_options(options)
)
val_dataset = (
tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
.batch(batch_size * strategy.num_replicas_in_sync)
.batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
.with_options(options)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment