Commit c36f55b7 authored by lucas_miranda

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 2f627d63
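The vectorised mapping itself lies outside the visible hunks, so the following is only a hedged sketch of the pattern the commit message describes: replacing a per-sample Python loop with a single tf.map_fn over the sampled indices. get_local_neighbourhood_entropy does appear in the diff below, but its body is not shown; the stand-in here computes a simple k-nearest-neighbour label entropy and should not be read as the project's actual implementation.

import tensorflow as tf

def local_neighbourhood_entropy(idx, encodings, categorical, k):
    # Toy stand-in: entropy of the hard cluster labels among the k nearest
    # encodings to encodings[idx] (squared Euclidean distance).
    distances = tf.reduce_sum((encodings - encodings[idx]) ** 2, axis=1)
    _, knn = tf.math.top_k(-distances, k=k)
    labels = tf.gather(tf.argmax(categorical, axis=1), knn)
    _, _, counts = tf.unique_with_counts(labels)
    p = tf.cast(counts, tf.float32) / tf.cast(k, tf.float32)
    return -tf.reduce_sum(p * tf.math.log(p))

# Before: per-sample Python loop, one small graph per index.
# entropies = [local_neighbourhood_entropy(i, encodings, categorical, k)
#              for i in random_idxs]

# After: one mapped op over all sampled indices at once.
def neighbourhood_entropies(random_idxs, encodings, categorical, k):
    return tf.map_fn(
        lambda i: local_neighbourhood_entropy(i, encodings, categorical, k),
        tf.convert_to_tensor(random_idxs),
        fn_output_signature=tf.float32,
    )

tf.vectorized_map is a possible drop-in alternative when the mapped function supports it.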
@@ -560,9 +560,9 @@ class ClusterOverlap(Layer):
         self,
         batch_size: int,
         encoding_dim: int,
-        k: int = 100,
+        k: int = 25,
         loss_weight: float = 0.0,
-        samples: int = 50,
+        samples: int = None,
         *args,
         **kwargs
     ):
@@ -572,6 +572,8 @@ class ClusterOverlap(Layer):
         self.loss_weight = loss_weight
         self.min_confidence = 0.25
         self.samples = samples
+        if self.samples is None:
+            self.samples = self.batch_size
         super(ClusterOverlap, self).__init__(*args, **kwargs)

     def get_config(self):  # pragma: no cover
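Assuming the elided constructor lines bind batch_size to an attribute, as the new None-check implies, a hypothetical construction under the new defaults would behave as follows (argument values invented for illustration):

# Hypothetical usage; the encoding_dim value and the import of ClusterOverlap
# from the project's model utilities are assumptions.
layer = ClusterOverlap(batch_size=64, encoding_dim=6)
assert layer.samples == 64  # samples=None now falls back to the batch size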
@@ -595,8 +597,8 @@ class ClusterOverlap(Layer):
         max_groups = tf.reduce_max(categorical, axis=1)

         # Iterate over samples and compute purity across neighbourhood
-        self.samples = np.min([self.samples, self.batch_size])  # convert to numpy
-        random_idxs = range(self.batch_size)  # convert to batch size
+        self.samples = np.min([self.samples, self.batch_size])
+        random_idxs = range(self.batch_size)
         random_idxs = np.random.choice(random_idxs, self.samples)

         get_local_neighbourhood_entropy = partial(
@@ -628,9 +630,7 @@ class ClusterOverlap(Layer):
             neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
         )

-        # if self.loss_weight:
-        #     self.add_loss(
-        #         self.loss_weight * neighbourhood_entropy, inputs=inputs
-        #     )
+        if self.loss_weight:
+            self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))

         return encodings
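The previously commented-out penalty is re-enabled, now collapsed to a scalar with tf.reduce_mean before being registered. Below is a minimal self-contained sketch of that gated add_loss pattern; it is not the real ClusterOverlap layer, and the per-sample term is a stand-in.

import tensorflow as tf

class EntropyPenalty(tf.keras.layers.Layer):
    # Pass-through layer sketching the gated add_loss pattern.
    def __init__(self, loss_weight=0.0, **kwargs):
        super().__init__(**kwargs)
        self.loss_weight = loss_weight

    def call(self, inputs):
        # Stand-in for the per-sample neighbourhood entropies.
        per_sample = tf.reduce_sum(tf.square(inputs), axis=-1)
        if self.loss_weight:
            # Collapse to a scalar before registering, as the new code does.
            self.add_loss(self.loss_weight * tf.reduce_mean(per_sample))
        return inputs

The inputs= keyword in the old commented-out call is deprecated in TF2's Layer.add_loss, which plausibly motivated the rewrite alongside the scalar reduction.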
@@ -324,7 +324,7 @@ def autoencoder_fitting(
     X_val_dataset = (
         tf.data.Dataset.from_tensor_slices(X_val)
         .with_options(options)
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
     )

     # Defines what to log on tensorboard (useful for trying out different models)
@@ -441,13 +441,13 @@ def autoencoder_fitting(
     # Convert data to tf.data.Dataset objects
     train_dataset = (
         tf.data.Dataset.from_tensor_slices((Xs, tuple(ys)))
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
         .shuffle(buffer_size=X_train.shape[0])
         .with_options(options)
     )

     val_dataset = (
         tf.data.Dataset.from_tensor_slices((Xvals, tuple(yvals)))
-        .batch(batch_size * strategy.num_replicas_in_sync)
+        .batch(batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
         .with_options(options)
     )
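drop_remainder=True pins every batch to exactly batch_size * strategy.num_replicas_in_sync elements, plausibly because ClusterOverlap receives batch_size at construction time and, with the new default, samples the whole batch. A minimal demonstration of the flag's effect:

import tensorflow as tf

# With 10 elements and batches of 4, the trailing batch of 2 is dropped,
# so every batch has the static shape (4,).
ds = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
print([b.numpy().tolist() for b in ds])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(ds.element_spec.shape)             # (4,)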