Commit 22f291e6 authored by lucas_miranda's avatar lucas_miranda

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent a757402f
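The change replaces the NumPy-based subsampling of batch indices with a single `tf.map_fn` over the whole batch, keeping the neighbourhood-purity computation inside the TensorFlow graph. Note that `batch_dims=0` is `tf.gather`'s default, so those additions only make the default explicit. A minimal sketch of the vectorised pattern, with illustrative names rather than the deepof API:

``` python
import tensorflow as tf

# illustrative stand-ins, not the deepof API
batch_size, encoding_dim = 64, 8
encodings = tf.random.normal((batch_size, encoding_dim))

def mean_distance_to_batch(index):
    # distance from the queried sample to every sample in the batch
    query = tf.gather(encodings, index)
    return tf.reduce_mean(tf.norm(encodings - query, axis=1))

# one traced map evaluates the function for every index in the batch,
# instead of looping over a random subset of indices in Python
per_sample = tf.map_fn(
    mean_distance_to_batch,
    tf.constant(list(range(batch_size))),
    dtype=tf.dtypes.float32,
)
```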
@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@tf.function
def get_k_nearest_neighbors(tensor, k, index):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
- query = tf.gather(tensor, index)
+ query = tf.gather(tensor, index, batch_dims=0)
distances = tf.norm(tensor - query, axis=1)
max_distance = tf.sort(distances)[k]
neighbourhood_mask = distances < max_distance
@@ -49,7 +49,7 @@ def get_k_nearest_neighbors(tensor, k, index):
@tf.function
def get_neighbourhood_entropy(index, tensor, clusters, k):
neighborhood = get_k_nearest_neighbors(tensor, k, index)
- cluster_z = tf.gather(clusters, neighborhood)
+ cluster_z = tf.gather(clusters, neighborhood, batch_dims=0)
neigh_entropy = compute_shannon_entropy(cluster_z)
return neigh_entropy
@@ -473,7 +473,6 @@ class ClusterOverlap(Layer):
encoding_dim: int,
k: int = 25,
loss_weight: float = 0.0,
- samples: int = None,
*args,
**kwargs
):
@@ -482,9 +481,6 @@ class ClusterOverlap(Layer):
self.k = k
self.loss_weight = loss_weight
self.min_confidence = 0.25
- self.samples = samples
- if self.samples is None:
- self.samples = self.batch_size
super(ClusterOverlap, self).__init__(*args, **kwargs)
def get_config(self): # pragma: no cover
@@ -507,11 +503,6 @@ class ClusterOverlap(Layer):
hard_groups = tf.math.argmax(categorical, axis=1)
max_groups = tf.reduce_max(categorical, axis=1)
# Iterate over samples and compute purity across neighbourhood
- self.samples = np.min([self.samples, self.batch_size])
- random_idxs = range(self.batch_size)
- random_idxs = np.random.choice(random_idxs, self.samples)
get_local_neighbourhood_entropy = partial(
get_neighbourhood_entropy,
tensor=encodings,
@@ -521,14 +512,12 @@ class ClusterOverlap(Layer):
purity_vector = tf.map_fn(
get_local_neighbourhood_entropy,
- tf.constant(random_idxs),
+ tf.constant(list(range(self.batch_size))),
dtype=tf.dtypes.float32,
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
- neighbourhood_entropy = purity_vector * tf.gather(
- max_groups, tf.constant(random_idxs)
- )
+ neighbourhood_entropy = purity_vector * max_groups
number_of_clusters = tf.cast(
tf.shape(
@@ -537,6 +526,7 @@ class ClusterOverlap(Layer):
tf.gather(
tf.cast(hard_groups, tf.dtypes.float32),
tf.where(max_groups >= self.min_confidence),
+ batch_dims=0,
),
[-1],
),
@@ -475,9 +475,9 @@ class GMVAE:
if self.number_of_components > 1:
z = deepof.model_utils.ClusterOverlap(
- self.batch_size,
- self.ENCODING,
- self.number_of_components,
+ batch_size=self.batch_size,
+ encoding_dim=self.ENCODING,
+ k=self.number_of_components,
loss_weight=self.overlap_loss,
)([z, z_cat])
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:markdown id: tags:
# deepOF model stability
%% Cell type:markdown id: tags:
Given a dataset and a set of trained models, this notebook allows the user to
* Group all weights according to their parameters
* Load the corresponding models
* Compute cluster assignment for a series of data points
* Compute and visualize the Adjusted Rand Index for each group (a short ARI sketch follows below)
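%% Cell type:markdown id: tags:
As a quick reminder of the metric used in the last step: `sklearn.metrics.adjusted_rand_score` compares two clusterings and is invariant to label permutations (a toy sketch, not project data):
%% Cell type:code id: tags:
``` python
from sklearn.metrics import adjusted_rand_score

# identical partitions up to a relabelling score 1.0
print(adjusted_rand_score([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0
# systematic disagreement scores below 0 (fewer agreements than chance)
print(adjusted_rand_score([0, 0, 1, 1], [0, 1, 0, 1]))  # -0.5
```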
%% Cell type:code id: tags:
``` python
import os
os.chdir(os.path.dirname("../"))
```
%% Cell type:code id: tags:
``` python
import deepof.data
import deepof.models
import deepof.utils
import numpy as np
import pandas as pd
import re
import tensorflow as tf
from itertools import combinations
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import adjusted_rand_score
```
%% Cell type:markdown id: tags:
### 1. Define and run project
%% Cell type:code id: tags:
``` python
path = os.path.join("..", "..", "Desktop", "deepof-data", "deepof_single_topview")
trained_network = os.path.join("..", "..", "Desktop", "deepof_trained_weights_280521", "var_annealing")
trained_network = os.path.join("..", "..", "Desktop", "deepof_trained_weights_280521", "var_overlap_loss")
exclude_bodyparts = tuple([""])
window_size = 24
```
%% Cell type:code id: tags:
``` python
%%time
proj = deepof.data.project(
path=path, smooth_alpha=0.999, exclude_bodyparts=exclude_bodyparts, arena_dims=[380],
)
```
%% Output
CPU times: user 43.9 s, sys: 2.86 s, total: 46.8 s
Wall time: 39.1 s
%% Cell type:code id: tags:
``` python
%%time
proj = proj.run(verbose=True)
print(proj)
```
%% Output
Loading trajectories...
Smoothing trajectories...
Interpolating outliers...
Iterative imputation of occluded bodyparts...
Computing distances...
Computing angles...
Done!
deepof analysis of 166 videos
CPU times: user 12min 34s, sys: 24.4 s, total: 12min 58s
Wall time: 3min 12s
%% Cell type:code id: tags:
``` python
coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
data = coords.preprocess(test_videos=0, window_step=24, window_size=window_size, shuffle=False)[0]
```
%% Cell type:code id: tags:
``` python
rand_idx = np.random.choice(range(data.shape[0]), 10000, replace=False)
data = data[rand_idx]
```
%% Cell type:markdown id: tags:
### 2. Load and group model weights
%% Cell type:code id: tags:
``` python
# Group based on training instance length
trained_weights = [os.path.join(trained_network, i) for i in os.listdir(trained_network) if i.endswith("h5")]
trained_weights_dict = {}
for tw in trained_weights:
    # group runs that share a warmup mode and a warmup length
    warmup_mode = re.findall(r"_warmup_mode=(\w*)_", tw)[0]
    length = re.findall(r"loss_warmup=(\d*)_", tw)[0]
    rid = "{}_{}".format(warmup_mode, length)
    trained_weights_dict.setdefault(rid, []).append(tw)
```
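%% Cell type:markdown id: tags:
For instance, two hypothetical weight files sharing a warmup mode and length end up under the same key (filenames below are illustrative, not actual runs):
%% Cell type:code id: tags:
``` python
import re

example = [
    "GMVAE_loss_warmup=25_warmup_mode=linear_run=1.h5",  # hypothetical
    "GMVAE_loss_warmup=25_warmup_mode=linear_run=2.h5",  # hypothetical
]
groups = {}
for tw in example:
    rid = "{}_{}".format(
        re.findall(r"_warmup_mode=(\w*)_", tw)[0],
        re.findall(r"loss_warmup=(\d*)_", tw)[0],
    )
    groups.setdefault(rid, []).append(tw)
print(groups)  # {'linear_25': [...both files...]}
```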
%% Cell type:markdown id: tags:
### 3. Load models and predict clusters for sampled data
%% Cell type:code id: tags:
``` python
groupings_dict = {}
```
%% Cell type:code id: tags:
``` python
def load_model_and_get_groupings(data, weights):
    # Recover model hyperparameters from the weight-file name
    encoding = int(re.findall(r"encoding=(\d+)_", weights)[0])
    k = int(re.findall(r"k=(\d+)_", weights)[0])
    loss = re.findall(r"loss=(.+?)_", weights)[0]
    NextSeqPred = float(re.findall(r"NextSeqPred=(.+?)_", weights)[0])
    PhenoPred = float(re.findall(r"PhenoPred=(.+?)_", weights)[0])
    RuleBasedPred = float(re.findall(r"RuleBasedPred=(.+?)_", weights)[0])

    # Rebuild the architecture, load the trained weights, and predict clusters
    (
        encode_to_vector,
        decoder,
        grouper,
        gmvaep,
        prior,
        posterior,
    ) = deepof.models.SEQ_2_SEQ_GMVAE(
        loss=loss,
        number_of_components=k,
        compile_model=True,
        encoding=encoding,
        next_sequence_prediction=NextSeqPred,
        phenotype_prediction=PhenoPred,
        rule_based_prediction=RuleBasedPred,
    ).build(data.shape)
    gmvaep.load_weights(weights)
    groups = grouper.predict(data)
    return groups
```
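%% Cell type:markdown id: tags:
The regular expressions above recover each run's hyperparameters from its weight-file name; a hypothetical filename illustrates the convention (the exact naming scheme is assumed here):
%% Cell type:code id: tags:
``` python
import re

fname = "GMVAE_loss=ELBO_encoding=6_k=15_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_run=1_final_weights.h5"  # hypothetical
print(int(re.findall(r"encoding=(\d+)_", fname)[0]))       # 6
print(float(re.findall(r"NextSeqPred=(.+?)_", fname)[0]))  # 0.15
```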
%% Cell type:code id: tags:
``` python
for k, v in trained_weights_dict.items():
    print(k)
    for tw in tqdm(v):
        # restrict to runs trained with the next-sequence prediction head only
        if "NextSeqPred=0.15" in tw and "PhenoPred=0.0" in tw and "RuleBasedPred=0.0" in tw:
            try:
                groupings = load_model_and_get_groupings(data, tw)
                try:
                    groupings_dict[k].append(groupings)
                except KeyError:
                    groupings_dict[k] = [groupings]
            except ValueError:
                # skip runs that raise a ValueError (e.g. an incompatible architecture)
                continue
```
%% Output
linear_25
linear_5
sigmoid_25
linear_20
linear_15
sigmoid_5
sigmoid_10
linear_10
sigmoid_20
sigmoid_15
%% Cell type:markdown id: tags:
### 4. Obtain ARI score and plot
%% Cell type:code id: tags:
``` python
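# one column per trained model: hard assignment = argmax over the soft cluster probabilities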
hard_groups_dict = {
k: np.concatenate([np.argmax(i, axis=1)[:, np.newaxis] for i in v], axis=1)
for k, v in groupings_dict.items()
}
```
%% Cell type:code id: tags:
``` python
from collections import Counter
for k,v in hard_groups_dict.items():
print(k)
print(Counter(list(v[:,0])))
```
%% Output
linear_25
Counter({3: 3850, 1: 2234, 4: 1231, 8: 1059, 10: 542, 13: 528, 11: 332, 9: 85, 14: 43, 12: 34, 2: 30, 6: 25, 0: 7})
linear_5
Counter({5: 2458, 1: 2404, 6: 2026, 10: 1207, 2: 745, 3: 461, 11: 368, 8: 114, 9: 64, 14: 63, 0: 36, 13: 31, 7: 14, 4: 9})
sigmoid_25
Counter({2: 4310, 1: 2665, 3: 1553, 12: 762, 7: 649, 5: 60, 0: 1})
linear_20
Counter({4: 4328, 7: 2692, 11: 861, 10: 819, 1: 432, 3: 346, 6: 343, 12: 112, 5: 37, 0: 10, 2: 10, 13: 6, 9: 4})
linear_15
Counter({10: 3242, 12: 2511, 8: 1811, 2: 1755, 7: 299, 3: 254, 1: 126, 9: 1, 0: 1})
sigmoid_5
Counter({11: 5779, 12: 1503, 13: 943, 9: 626, 5: 551, 3: 369, 1: 122, 2: 62, 0: 32, 8: 8, 4: 4, 14: 1})
sigmoid_10
Counter({1: 2079, 13: 2063, 0: 1967, 6: 854, 8: 824, 2: 703, 5: 495, 4: 378, 10: 267, 7: 175, 11: 80, 14: 57, 3: 43, 9: 15})
linear_10
Counter({2: 4258, 8: 2366, 14: 2324, 12: 286, 7: 247, 1: 162, 3: 128, 5: 66, 9: 59, 11: 53, 0: 35, 10: 9, 6: 6, 4: 1})
sigmoid_20
Counter({6: 3750, 1: 2603, 11: 1572, 14: 571, 9: 494, 8: 457, 2: 395, 12: 94, 3: 34, 5: 18, 13: 9, 7: 2, 4: 1})
sigmoid_15
Counter({6: 5090, 13: 2582, 0: 662, 9: 586, 2: 451, 12: 266, 1: 175, 14: 66, 3: 51, 10: 45, 8: 24, 5: 2})
%% Cell type:code id: tags:
``` python
def extended_ARI(groupings):
    """Return the Adjusted Rand Index for every pair of model runs (columns)"""
    comparisons = list(combinations(range(groupings.shape[1]), 2))
    ari = [
        adjusted_rand_score(groupings[:, comp[0]], groupings[:, comp[1]])
        for comp in comparisons
    ]
    return ari
```
%% Cell type:code id: tags:
``` python
ari_dict = {k: extended_ARI(v) for k, v in hard_groups_dict.items()}
simplified_keys = [
"overlap_loss=0.1",
"overlap_loss=0.2",
"overlap_loss=0.5",
"overlap_loss=0.75",
"overlap_loss=1.0",
]
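# note: zip truncates to the shorter sequence, so only the first five groups are relabelled and kept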
ari_dict = {k:v for k,v in zip(simplified_keys, list(ari_dict.values()))}
ari_dict = {k:[i for i in v if i > 0] for k,v in ari_dict.items()}
```
%% Cell type:code id: tags:
``` python
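# pad each group to 45 values (presumably C(10, 2) pairwise comparisons of 10 runs) with the group mean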
for k, v in ari_dict.items():
    if len(v) < 45:
        for i in range(45 - len(v)):
            v.append(np.mean(np.array(v)))
    ari_dict[k] = v
```
%% Cell type:code id: tags:
``` python
plt.figure(figsize=(12,8))
ari_df = pd.DataFrame(ari_dict).melt()
sns.boxplot(data=ari_df, x="value", y="variable")
#plt.xlim(0,1)
plt.savefig("deepof_variable_warmup.svg")
plt.show()
```
%% Output
%% Cell type:code id: tags:
``` python
```