Skip to content
Snippets Groups Projects
Commit c36f55b7 authored by Lucas Miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 2f627d63
No related branches found
No related tags found
No related merge requests found
......@@ -560,9 +560,9 @@ class ClusterOverlap(Layer):
self,
batch_size: int,
encoding_dim: int,
k: int = 100,
k: int = 25,
loss_weight: float = 0.0,
samples: int = 50,
samples: int = None,
*args,
**kwargs
):
......@@ -572,6 +572,8 @@ class ClusterOverlap(Layer):
self.loss_weight = loss_weight
self.min_confidence = 0.25
self.samples = samples
if self.samples is None:
self.samples = self.batch_size
super(ClusterOverlap, self).__init__(*args, **kwargs)
def get_config(self): # pragma: no cover
......@@ -595,8 +597,8 @@ class ClusterOverlap(Layer):
max_groups = tf.reduce_max(categorical, axis=1)
# Iterate over samples and compute purity across neighbourhood
self.samples = np.min([self.samples, self.batch_size]) # convert to numpy
random_idxs = range(self.batch_size) # convert to batch size
self.samples = np.min([self.samples, self.batch_size])
random_idxs = range(self.batch_size)
random_idxs = np.random.choice(random_idxs, self.samples)
get_local_neighbourhood_entropy = partial(
......@@ -628,9 +630,7 @@ class ClusterOverlap(Layer):
neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
)
# if self.loss_weight:
# self.add_loss(
# self.loss_weight * neighbourhood_entropy, inputs=inputs
# )
if self.loss_weight:
self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
return encodings
......@@ -324,7 +324,7 @@ def autoencoder_fitting(
X_val_dataset = (
tf.data.Dataset.from_tensor_slices(X_val)
.with_options(options)
.batch(batch_size * strategy.num_replicas_in_sync)
.batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
)
# Defines what to log on tensorboard (useful for trying out different models)
......@@ -441,13 +441,13 @@ def autoencoder_fitting(
# Convert data to tf.data.Dataset objects
train_dataset = (
tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
.batch(batch_size * strategy.num_replicas_in_sync)
.batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
.shuffle(buffer_size=X_train.shape[0])
.with_options(options)
)
val_dataset = (
tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
.batch(batch_size * strategy.num_replicas_in_sync)
.batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
.with_options(options)
)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment.