Commit 4cf9fb3f authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent 8938e107
......@@ -619,21 +619,23 @@ class ClusterOverlap(Layer):
max_groups, tf.constant(random_idxs)
)
self.add_metric(
tf.cast(
tf.shape(
tf.unique(
tf.reshape(
tf.gather(
tf.cast(hard_groups, tf.dtypes.float32),
tf.where(max_groups >= self.min_confidence),
),
[-1],
number_of_clusters = tf.cast(
tf.shape(
tf.unique(
tf.reshape(
tf.gather(
tf.cast(hard_groups, tf.dtypes.float32),
tf.where(max_groups >= self.min_confidence),
),
)[0],
[-1],
),
)[0],
tf.dtypes.float32,
),
)[0],
tf.dtypes.float32,
)
self.add_metric(
number_of_clusters,
name="number_of_populated_clusters",
)
......@@ -648,6 +650,9 @@ class ClusterOverlap(Layer):
)
if self.loss_weight:
# minimize local entropy
self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
# maximize number of clusters
self.add_loss(-self.loss_weight * tf.reduce_mean(number_of_clusters))
return encodings
......@@ -28,7 +28,7 @@ phenotype_pred_weights = [0.0]
rule_based_pred_weights = [0.0]
window_lengths = [22] # range(11,56,11)
input_types = ["coords"]
run = list(range(1, 11))
run = list(range(1, 4))
rule deepof_experiments:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment