Commit b9a11499 authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent a7080a5c
Pipeline #101725 failed with stages
in 22 minutes and 25 seconds
......@@ -617,12 +617,19 @@ class ClusterOverlap(Layer):
)
self.add_metric(
tf.shape(
tf.unique(
tf.squeeze(tf.gather(hard_groups, tf.where(max_groups >= self.min_confidence)))
)
tf.cast(
tf.shape(
tf.unique(
tf.squeeze(
tf.gather(
tf.cast(hard_groups, tf.dtypes.float32),
tf.where(max_groups >= self.min_confidence),
),
),
)[0],
)[0],
tf.dtypes.float32,
),
aggregation="mean",
name="number_of_populated_clusters",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment