Commit 39be1caa authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent dae6a811
Pipeline #102635 passed with stages
in 18 minutes and 44 seconds
......@@ -495,65 +495,67 @@ class ClusterOverlap(Layer):
config.update({"samples": self.samples})
return config
def call(self, inputs, training=None, **kwargs):
    """Return the encodings unchanged; add cluster metrics and loss when training.

    Args:
        inputs: indexable pair. ``inputs[0]`` is the encodings tensor and
            ``inputs[1]`` the categorical tensor, which is reduced along
            axis 1 below (presumably per-sample cluster assignment
            probabilities -- confirm against the caller).
        training: when truthy, the neighbourhood-entropy regulariser, the
            three metrics and (if ``self.loss_weight`` is truthy) the loss
            are computed and registered. When falsy or ``None`` nothing is
            computed and the layer is a pure pass-through.
        **kwargs: accepted for signature compatibility; not used here.

    Returns:
        ``inputs[0]`` (the encodings), unmodified.
    """

    encodings, categorical = inputs[0], inputs[1]

    if training:
        # Index of the largest entry along axis 1, and the value of that entry.
        hard_groups = tf.math.argmax(categorical, axis=1)
        max_groups = tf.reduce_max(categorical, axis=1)

        # Bind everything except the sample index, which tf.map_fn supplies
        # as the single positional argument.
        get_local_neighbourhood_entropy = partial(
            get_neighbourhood_entropy,
            tensor=encodings,
            clusters=hard_groups,
            k=self.k,
        )

        # One value per index in [0, batch_size). tf.range builds the same
        # int32 index vector as tf.constant(list(range(...))) without the
        # intermediate Python list.
        purity_vector = tf.map_fn(
            get_local_neighbourhood_entropy,
            tf.range(self.batch_size),
            dtype=tf.dtypes.float32,
        )

        ### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
        neighbourhood_entropy = purity_vector * max_groups

        # Count the distinct hard assignments among the samples whose top
        # value reaches self.min_confidence: gather -> flatten -> unique ->
        # length of the unique-values tensor.
        number_of_clusters = tf.cast(
            tf.shape(
                tf.unique(
                    tf.reshape(
                        tf.gather(
                            tf.cast(hard_groups, tf.dtypes.float32),
                            tf.where(max_groups >= self.min_confidence),
                            batch_dims=0,
                        ),
                        [-1],
                    ),
                )[0],
            )[0],
            tf.dtypes.float32,
        )

        self.add_metric(
            number_of_clusters,
            name="number_of_populated_clusters",
        )

        self.add_metric(
            max_groups,
            aggregation="mean",
            name="average_confidence_in_selected_cluster",
        )

        self.add_metric(
            neighbourhood_entropy, aggregation="mean", name="neighbourhood_entropy"
        )

        if self.loss_weight:
            # minimize local entropy
            self.add_loss(self.loss_weight * tf.reduce_mean(neighbourhood_entropy))
            # maximize number of clusters
            # self.add_loss(-self.loss_weight * tf.reduce_mean(number_of_clusters))

    return encodings
......@@ -628,17 +628,9 @@
},
{
"cell_type": "code",
"execution_count": 193,
"execution_count": 195,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(\"cluster_overlap_44/Shape:0\", shape=(2,), dtype=int32)\n"
]
}
],
"outputs": [],
"source": [
"(\n",
" encoder,\n",
......@@ -664,101 +656,108 @@
},
{
"cell_type": "code",
"execution_count": 189,
"execution_count": 194,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(\"SEQ_2_SEQ_GMVAE/cluster_overlap_44/Shape:0\", shape=(2,), dtype=int32)\n"
]
},
{
"data": {
"text/plain": [
"array([[[-2.20274441e-02, -7.49340594e-01, -1.00606358e+00, ...,\n",
" 5.97163141e-01, 7.87078515e-02, 3.10142875e-01],\n",
" [-1.86636951e-02, -7.35167682e-01, -1.04070246e+00, ...,\n",
" 5.36489308e-01, 6.65854439e-02, 2.50531852e-01],\n",
" [-1.82860009e-02, -7.65824318e-01, -1.08549905e+00, ...,\n",
" 5.57325482e-01, 3.26592624e-02, 2.32314825e-01],\n",
"array([[[-2.33654119e-02, -7.18961537e-01, -1.04058743e+00, ...,\n",
" 6.04326367e-01, -8.63653272e-02, 3.18874061e-01],\n",
" [-2.12193187e-02, -7.04667985e-01, -1.08008337e+00, ...,\n",
" 5.50029635e-01, -9.52750891e-02, 2.61153668e-01],\n",
" [-2.03663222e-02, -7.36600041e-01, -1.12794685e+00, ...,\n",
" 5.78937709e-01, -1.28708437e-01, 2.42814869e-01],\n",
" ...,\n",
" [-6.10359153e-03, -7.89098859e-01, -1.30699337e+00, ...,\n",
" 6.69562757e-01, -4.62342173e-01, 3.77381265e-01],\n",
" [-4.65778168e-03, -8.11704040e-01, -1.28605306e+00, ...,\n",
" 7.13757634e-01, -4.67065245e-01, 4.34089184e-01],\n",
" [-6.71512820e-03, -8.53176653e-01, -1.29082215e+00, ...,\n",
" 7.71026969e-01, -5.02031207e-01, 4.69406664e-01]],\n",
" [-9.44020599e-03, -7.42647946e-01, -1.30498624e+00, ...,\n",
" 6.58422947e-01, -5.47523737e-01, 3.99996579e-01],\n",
" [-7.29610212e-03, -7.65350342e-01, -1.27779627e+00, ...,\n",
" 6.95867777e-01, -5.41299820e-01, 4.53622282e-01],\n",
" [-9.27564688e-03, -8.18071365e-01, -1.28233838e+00, ...,\n",
" 7.52779126e-01, -5.64222336e-01, 4.90554869e-01]],\n",
"\n",
" [[-2.22376399e-02, -7.70553827e-01, -9.43590641e-01, ...,\n",
" 5.78278065e-01, 2.24873498e-01, 3.14347565e-01],\n",
" [-1.82316527e-02, -7.55171597e-01, -9.75605845e-01, ...,\n",
" 5.09953022e-01, 2.11454228e-01, 2.45110869e-01],\n",
" [-1.79231763e-02, -7.78624117e-01, -1.02039325e+00, ...,\n",
" 5.19503474e-01, 1.74318790e-01, 2.36997604e-01],\n",
" [[-2.22253725e-02, -8.41416478e-01, -1.12681007e+00, ...,\n",
" 6.94921017e-01, -3.09668630e-01, 4.00050104e-01],\n",
" [-2.30817143e-02, -8.39012742e-01, -1.17822659e+00, ...,\n",
" 6.44349039e-01, -3.31067830e-01, 3.31252694e-01],\n",
" [-2.34456696e-02, -8.76340389e-01, -1.22851527e+00, ...,\n",
" 6.79111958e-01, -3.67583066e-01, 3.47556919e-01],\n",
" ...,\n",
" [-3.98765504e-03, -7.79918849e-01, -1.20278263e+00, ...,\n",
" 5.98293304e-01, -2.54329056e-01, 3.15500557e-01],\n",
" [-3.29605211e-03, -8.03981423e-01, -1.18459630e+00, ...,\n",
" 6.46142304e-01, -2.72717953e-01, 3.54292005e-01],\n",
" [-4.49736090e-03, -8.39645684e-01, -1.18976438e+00, ...,\n",
" 7.12202549e-01, -3.01250070e-01, 3.82732034e-01]],\n",
" [-1.15935905e-02, -8.33023310e-01, -1.16872275e+00, ...,\n",
" 6.26452863e-01, -3.46652746e-01, 3.66459996e-01],\n",
" [-7.99606554e-03, -8.50840628e-01, -1.14563727e+00, ...,\n",
" 6.63835168e-01, -3.26457262e-01, 3.89723182e-01],\n",
" [-8.63283966e-03, -9.02453482e-01, -1.13931096e+00, ...,\n",
" 7.31720626e-01, -3.24028820e-01, 4.03006196e-01]],\n",
"\n",
" [[-2.36156955e-02, -7.91016340e-01, -1.11879790e+00, ...,\n",
" 7.13154435e-01, -3.65756422e-01, 4.06932175e-01],\n",
" [-2.54947953e-02, -7.88331568e-01, -1.17535889e+00, ...,\n",
" 6.70483768e-01, -3.85737628e-01, 3.32547605e-01],\n",
" [-2.56239660e-02, -8.24208558e-01, -1.23989546e+00, ...,\n",
" 7.14474797e-01, -4.28516060e-01, 3.39852810e-01],\n",
" [[-2.67037526e-02, -7.14372754e-01, -1.01690137e+00, ...,\n",
" 5.67512035e-01, -7.96585530e-02, 2.87923396e-01],\n",
" [-2.66271308e-02, -7.22530246e-01, -1.08599091e+00, ...,\n",
" 5.39661050e-01, -1.06031761e-01, 2.33209401e-01],\n",
" [-2.72963755e-02, -7.53687024e-01, -1.14411652e+00, ...,\n",
" 5.67179024e-01, -1.51002675e-01, 2.36247778e-01],\n",
" ...,\n",
" [-1.38379335e-02, -7.91984320e-01, -1.10805249e+00, ...,\n",
" 6.95663571e-01, -3.35209459e-01, 3.72855306e-01],\n",
" [-8.09174031e-03, -8.20150554e-01, -1.07291603e+00, ...,\n",
" 7.31382430e-01, -3.07114869e-01, 3.89692128e-01],\n",
" [-7.56193697e-03, -8.81913245e-01, -1.08008969e+00, ...,\n",
" 7.96405435e-01, -2.96295941e-01, 4.00389731e-01]],\n",
" [-9.20080859e-03, -7.82252252e-01, -9.62321341e-01, ...,\n",
" 6.10508740e-01, -6.60327971e-02, 2.83200055e-01],\n",
" [-3.12909577e-03, -7.93930829e-01, -9.27874446e-01, ...,\n",
" 6.45921230e-01, -4.59667295e-02, 2.94262946e-01],\n",
" [-2.58489698e-03, -8.41321766e-01, -9.23104703e-01, ...,\n",
" 7.15546906e-01, -3.20333168e-02, 2.93513715e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.27419364e-03, -1.70991755e+00, -7.36668050e-01, ...,\n",
" 1.14400041e+00, 1.80199802e-01, 4.03101444e-01],\n",
" [ 9.55599174e-03, -1.77840245e+00, -7.31509030e-01, ...,\n",
" 1.10214877e+00, 1.92172885e-01, 3.62863183e-01],\n",
" [ 6.99162669e-03, -1.84455645e+00, -6.68665648e-01, ...,\n",
" 1.12744784e+00, 1.78093016e-01, 3.79255176e-01],\n",
" [[ 1.11031905e-03, -1.75031507e+00, -8.52379441e-01, ...,\n",
" 1.21988916e+00, 1.03995144e-01, 4.59409773e-01],\n",
" [ 8.80013034e-03, -1.83671558e+00, -8.69083345e-01, ...,\n",
" 1.19457781e+00, 1.23950966e-01, 4.14475203e-01],\n",
" [ 6.59639947e-03, -1.91086435e+00, -8.22446942e-01, ...,\n",
" 1.22027707e+00, 1.20491169e-01, 4.16057706e-01],\n",
" ...,\n",
" [ 1.46682467e-02, -1.67971981e+00, -5.29618204e-01, ...,\n",
" 1.07131672e+00, 2.70632327e-01, 3.96659851e-01],\n",
" [ 1.90188847e-02, -1.64219010e+00, -5.86603642e-01, ...,\n",
" 1.06212115e+00, 2.97838628e-01, 4.20529246e-01],\n",
" [ 1.77764259e-02, -1.60712445e+00, -5.87549746e-01, ...,\n",
" 1.08539033e+00, 3.32792282e-01, 4.31078553e-01]],\n",
" [ 1.49972383e-02, -1.74408889e+00, -5.97882569e-01, ...,\n",
" 1.21293759e+00, 4.05578047e-01, 5.18048167e-01],\n",
" [ 1.90434121e-02, -1.69634783e+00, -6.43523514e-01, ...,\n",
" 1.19492638e+00, 4.35988188e-01, 5.23940802e-01],\n",
" [ 1.63217653e-02, -1.64555275e+00, -6.33343041e-01, ...,\n",
" 1.19980395e+00, 4.42245930e-01, 5.15006006e-01]],\n",
"\n",
" [[ 5.71209192e-03, -1.57354665e+00, -8.53239119e-01, ...,\n",
" 1.10862482e+00, 1.20560825e-01, 4.82771069e-01],\n",
" [ 1.34701636e-02, -1.63361990e+00, -8.83263171e-01, ...,\n",
" 1.09039259e+00, 1.39529616e-01, 4.46821988e-01],\n",
" [ 1.27676763e-02, -1.69738317e+00, -8.67947996e-01, ...,\n",
" 1.11015606e+00, 1.38534516e-01, 4.43767011e-01],\n",
" [[ 2.67705321e-03, -1.70834327e+00, -7.96462834e-01, ...,\n",
" 1.17186177e+00, 1.32418081e-01, 4.49965090e-01],\n",
" [ 1.02437194e-02, -1.78479028e+00, -8.03244650e-01, ...,\n",
" 1.13687229e+00, 1.46574080e-01, 4.06593144e-01],\n",
" [ 7.09127262e-03, -1.85855865e+00, -7.63118446e-01, ...,\n",
" 1.15632105e+00, 1.33105025e-01, 4.12189573e-01],\n",
" ...,\n",
" [ 2.44371835e-02, -1.57842243e+00, -5.92557728e-01, ...,\n",
" 1.22556579e+00, 5.59510589e-01, 6.62321389e-01],\n",
" [ 2.63634715e-02, -1.54089200e+00, -6.52207255e-01, ...,\n",
" 1.19942415e+00, 5.94571948e-01, 6.66273296e-01],\n",
" [ 2.19873693e-02, -1.52749789e+00, -6.52528405e-01, ...,\n",
" 1.23534441e+00, 5.85497737e-01, 6.72029912e-01]],\n",
" [ 1.81343500e-02, -1.67693102e+00, -5.01757801e-01, ...,\n",
" 1.16755402e+00, 4.76390660e-01, 5.43144941e-01],\n",
" [ 2.21438985e-02, -1.63842762e+00, -5.54517329e-01, ...,\n",
" 1.15032852e+00, 5.08659005e-01, 5.53249776e-01],\n",
" [ 1.89741328e-02, -1.60885894e+00, -5.58988631e-01, ...,\n",
" 1.17562830e+00, 5.15346706e-01, 5.59121966e-01]],\n",
"\n",
" [[ 6.25598989e-03, -1.62235355e+00, -8.73309553e-01, ...,\n",
" 1.13996327e+00, 1.29154831e-01, 4.71493900e-01],\n",
" [ 1.32705066e-02, -1.68581510e+00, -9.02929783e-01, ...,\n",
" 1.12689865e+00, 1.41210854e-01, 4.33707178e-01],\n",
" [ 1.26603413e-02, -1.74876869e+00, -8.90606523e-01, ...,\n",
" 1.14830673e+00, 1.26072347e-01, 4.30647731e-01],\n",
" [[ 1.44676864e-03, -1.54619968e+00, -9.54818726e-01, ...,\n",
" 1.10726678e+00, -8.50950629e-02, 4.39189941e-01],\n",
" [ 6.98623247e-03, -1.59582186e+00, -9.81045246e-01, ...,\n",
" 1.09904671e+00, -8.24306607e-02, 3.94697845e-01],\n",
" [ 7.70333409e-03, -1.66229522e+00, -9.75158274e-01, ...,\n",
" 1.12477064e+00, -9.11763161e-02, 3.99113178e-01],\n",
" ...,\n",
" [ 2.57094558e-02, -1.64026582e+00, -6.49993658e-01, ...,\n",
" 1.23270714e+00, 4.76267278e-01, 6.13111377e-01],\n",
" [ 2.90074851e-02, -1.60244536e+00, -6.78030908e-01, ...,\n",
" 1.21040225e+00, 5.14402866e-01, 6.22853637e-01],\n",
" [ 2.53587011e-02, -1.59674680e+00, -6.75386965e-01, ...,\n",
" 1.23939729e+00, 5.17669916e-01, 6.38859391e-01]]],\n",
" [ 2.38139462e-02, -1.57800007e+00, -6.76846921e-01, ...,\n",
" 1.23930001e+00, 4.26834762e-01, 5.95292032e-01],\n",
" [ 2.62587424e-02, -1.54779446e+00, -7.09348798e-01, ...,\n",
" 1.21432519e+00, 4.73278880e-01, 5.99520683e-01],\n",
" [ 2.12610681e-02, -1.53637493e+00, -6.99584544e-01, ...,\n",
" 1.24682486e+00, 4.80561942e-01, 6.04046106e-01]]],\n",
" dtype=float32)"
]
},
"execution_count": 189,
"execution_count": 194,
"metadata": {},
"output_type": "execute_result"
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment