Commit baabe478 authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent acc06bfb
...@@ -292,13 +292,12 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback): ...@@ -292,13 +292,12 @@ class neighbor_latent_entropy(tf.keras.callbacks.Callback):
range(encoding.shape[0]), self.samples, replace=False range(encoding.shape[0]), self.samples, replace=False
) )
get_local_neighbourhood_entropy = partial( @tf.function
get_neighbourhood_entropy, def get_local_neighbourhood_entropy(index):
tensor=encodings, return get_neighbourhood_entropy(
clusters=hard_groups, index, tensor=encodings, clusters=hard_groups, k=self.k
k=self.k, )
dtype=tf.dtypes.float32,
)
purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs) purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
writer = tf.summary.create_file_writer(self.log_dir) writer = tf.summary.create_file_writer(self.log_dir)
...@@ -596,13 +595,12 @@ class ClusterOverlap(Layer): ...@@ -596,13 +595,12 @@ class ClusterOverlap(Layer):
tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), self.samples tf.expand_dims(random_idxs / tf.reduce_sum(random_idxs), 0), self.samples
) )
get_local_neighbourhood_entropy = partial( @tf.function
get_neighbourhood_entropy, def get_local_neighbourhood_entropy(index):
tensor=encodings, return get_neighbourhood_entropy(
clusters=hard_groups, index, tensor=encodings, clusters=hard_groups, k=self.k
k=self.k, )
dtype=tf.dtypes.float32,
)
purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs) purity_vector = tf.map_fn(get_local_neighbourhood_entropy, random_idxs)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ### ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment