Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
22f291e6
Commit
22f291e6
authored
May 31, 2021
by
lucas_miranda
Browse files
Replaced for loop with vectorised mapping on ClusterOverlap regularization layer
parent
a757402f
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
22f291e6
...
...
@@ -39,7 +39,7 @@ def compute_shannon_entropy(tensor):
@
tf
.
function
def
get_k_nearest_neighbors
(
tensor
,
k
,
index
):
"""Retrieve indices of the k nearest neighbors in tensor to the vector with the specified index"""
query
=
tf
.
gather
(
tensor
,
index
)
query
=
tf
.
gather
(
tensor
,
index
,
batch_dims
=
0
)
distances
=
tf
.
norm
(
tensor
-
query
,
axis
=
1
)
max_distance
=
tf
.
sort
(
distances
)[
k
]
neighbourhood_mask
=
distances
<
max_distance
...
...
@@ -49,7 +49,7 @@ def get_k_nearest_neighbors(tensor, k, index):
@tf.function
def get_neighbourhood_entropy(index, tensor, clusters, k):
    """Shannon entropy of cluster assignments in the k-NN neighbourhood.

    Looks up the k nearest neighbours of the row at ``index`` in ``tensor``,
    gathers their cluster labels from ``clusters``, and returns the Shannon
    entropy of that label set via ``compute_shannon_entropy``.
    """
    knn_indices = get_k_nearest_neighbors(tensor, k, index)
    neighbour_clusters = tf.gather(clusters, knn_indices, batch_dims=0)
    return compute_shannon_entropy(neighbour_clusters)
...
...
@@ -473,7 +473,6 @@ class ClusterOverlap(Layer):
encoding_dim
:
int
,
k
:
int
=
25
,
loss_weight
:
float
=
0.0
,
samples
:
int
=
None
,
*
args
,
**
kwargs
):
...
...
@@ -482,9 +481,6 @@ class ClusterOverlap(Layer):
self
.
k
=
k
self
.
loss_weight
=
loss_weight
self
.
min_confidence
=
0.25
self
.
samples
=
samples
if
self
.
samples
is
None
:
self
.
samples
=
self
.
batch_size
super
(
ClusterOverlap
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
get_config
(
self
):
# pragma: no cover
...
...
@@ -507,11 +503,6 @@ class ClusterOverlap(Layer):
hard_groups
=
tf
.
math
.
argmax
(
categorical
,
axis
=
1
)
max_groups
=
tf
.
reduce_max
(
categorical
,
axis
=
1
)
# Iterate over samples and compute purity across neighbourhood
self
.
samples
=
np
.
min
([
self
.
samples
,
self
.
batch_size
])
random_idxs
=
range
(
self
.
batch_size
)
random_idxs
=
np
.
random
.
choice
(
random_idxs
,
self
.
samples
)
get_local_neighbourhood_entropy
=
partial
(
get_neighbourhood_entropy
,
tensor
=
encodings
,
...
...
@@ -521,14 +512,12 @@ class ClusterOverlap(Layer):
purity_vector
=
tf
.
map_fn
(
get_local_neighbourhood_entropy
,
tf
.
constant
(
random_idxs
),
tf
.
constant
(
list
(
range
(
self
.
batch_size
))
),
dtype
=
tf
.
dtypes
.
float32
,
)
### CANDIDATE FOR REMOVAL. EXPLORE HOW USEFUL THIS REALLY IS ###
neighbourhood_entropy
=
purity_vector
*
tf
.
gather
(
max_groups
,
tf
.
constant
(
random_idxs
)
)
neighbourhood_entropy
=
purity_vector
*
max_groups
number_of_clusters
=
tf
.
cast
(
tf
.
shape
(
...
...
@@ -537,6 +526,7 @@ class ClusterOverlap(Layer):
tf
.
gather
(
tf
.
cast
(
hard_groups
,
tf
.
dtypes
.
float32
),
tf
.
where
(
max_groups
>=
self
.
min_confidence
),
batch_dims
=
0
,
),
[
-
1
],
),
...
...
deepof/models.py
View file @
22f291e6
...
...
@@ -475,9 +475,9 @@ class GMVAE:
if
self
.
number_of_components
>
1
:
z
=
deepof
.
model_utils
.
ClusterOverlap
(
self
.
batch_size
,
self
.
ENCODING
,
self
.
number_of_components
,
batch_size
=
self
.
batch_size
,
encoding_dim
=
self
.
ENCODING
,
k
=
self
.
number_of_components
,
loss_weight
=
self
.
overlap_loss
,
)([
z
,
z_cat
])
...
...
supplementary_notebooks/deepof_explore_model_stability.ipynb
View file @
22f291e6
...
...
@@ -86,7 +86,7 @@
"outputs": [],
"source": [
"path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof-data\", \"deepof_single_topview\")\n",
"trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof_trained_weights_280521\", \"var_
annealing
\")\n",
"trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof_trained_weights_280521\", \"var_
overlap_loss
\")\n",
"exclude_bodyparts = tuple([\"\"])\n",
"window_size = 24"
]
...
...
%% Cell type:code id: tags:
```
python
%
load_ext
autoreload
%
autoreload
2
```
%% Cell type:code id: tags:
```
python
import
warnings
warnings
.
filterwarnings
(
"ignore"
)
```
%% Cell type:markdown id: tags:
# deepOF model stability
%% Cell type:markdown id: tags:
Given a dataset and a set of trained models, this notebook allows the user to
*
Group all weights according to their parameters
*
Load the corresponding models
*
Compute cluster assignment for a series of data points
*
Compute and visualize the Adjusted Rand Index for each group
%% Cell type:code id: tags:
```
python
import
os
os
.
chdir
(
os
.
path
.
dirname
(
"../"
))
```
%% Cell type:code id: tags:
```
python
import
deepof.data
import
deepof.utils
import
numpy
as
np
import
pandas
as
pd
import
re
import
tensorflow
as
tf
from
itertools
import
combinations
from
tqdm
import
tqdm_notebook
as
tqdm
import
matplotlib.pyplot
as
plt
import
seaborn
as
sns
from
sklearn.metrics
import
adjusted_rand_score
```
%% Cell type:markdown id: tags:
### 1. Define and run project
%% Cell type:code id: tags:
```
python
# Locations of the raw dataset and the pretrained weights, relative to the
# user's home directory.
path = os.path.join("..", "..", "Desktop", "deepof-data", "deepof_single_topview")
trained_network = os.path.join(
    "..", "..", "Desktop", "deepof_trained_weights_280521", "var_overlap_loss"
)
exclude_bodyparts = ("",)
window_size = 24
```
%% Cell type:code id: tags:
```
python
%%
time
# Build the deepOF project object; smoothing and arena size are fixed here,
# body-part exclusions come from the setup cell above.
project_kwargs = dict(
    path=path,
    smooth_alpha=0.999,
    exclude_bodyparts=exclude_bodyparts,
    arena_dims=[380],
)
proj = deepof.data.project(**project_kwargs)
```
%% Output
CPU times: user 43.9 s, sys: 2.86 s, total: 46.8 s
Wall time: 39.1 s
%% Cell type:code id: tags:
```
python
%%
time
proj
=
proj
.
run
(
verbose
=
True
)
print
(
proj
)
```
%% Output
Loading trajectories...
Smoothing trajectories...
Interpolating outliers...
Iterative imputation of ocluded bodyparts...
Computing distances...
Computing angles...
Done!
deepof analysis of 166 videos
CPU times: user 12min 34s, sys: 24.4 s, total: 12min 58s
Wall time: 3min 12s
%% Cell type:code id: tags:
```
python
# Egocentric coordinates: centred on "Center" and rotated in place so that
# "Spine_1" is aligned across frames.
coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)

# preprocess returns (train, test); with test_videos=0 only index 0 is useful.
preprocessed = coords.preprocess(
    test_videos=0,
    window_step=24,
    window_size=window_size,
    shuffle=False,
)
data = preprocessed[0]
```
%% Cell type:code id: tags:
```
python
# Subsample 10k windows without replacement to keep downstream clustering
# and ARI computations tractable.
n_windows = data.shape[0]
rand_idx = np.random.choice(range(n_windows), 10000, replace=False)
data = data[rand_idx]
```
%% Cell type:markdown id: tags:
### 2. Load and group model weights
%% Cell type:code id: tags:
```
python
# Group trained weight files by the (warmup_mode, warmup_length) pair encoded
# in each filename, so runs with identical hyperparameters can be compared.
trained_weights = [
    os.path.join(trained_network, fname)
    for fname in os.listdir(trained_network)
    if fname.endswith("h5")
]
trained_weights_dict = {}
for tw in trained_weights:
    # Raw strings: "\w" / "\d" in plain strings are invalid escape sequences
    # and raise warnings on modern Python.
    warmup_mode = re.findall(r"_warmup_mode=(\w*)_", tw)[0]
    length = re.findall(r"loss_warmup=(\d*)_", tw)[0]
    rid = "{}_{}".format(warmup_mode, length)
    # setdefault replaces the original linear scan over all keys plus an
    # 'added' boolean flag — same grouping, O(1) per file.
    trained_weights_dict.setdefault(rid, []).append(tw)
```
%% Cell type:markdown id: tags:
### 3. Load models and predict clusters for sampled data
%% Cell type:code id: tags:
```
python
groupings_dict
=
{}
```
%% Cell type:code id: tags:
```
python
def load_model_and_get_groupings(data, weights):
    """Rebuild a SEQ_2_SEQ_GMVAE from a weights filename and predict clusters.

    Hyperparameters (encoding size, number of components, loss name and the
    three auxiliary-loss weights) are parsed out of the ``weights`` filename,
    the model is rebuilt to match, the stored weights are loaded, and the
    grouper sub-model's soft cluster assignments for ``data`` are returned.

    Raises IndexError if any expected hyperparameter tag is missing from the
    filename, and ValueError from Keras if the weights do not fit the model.
    """
    # Raw strings avoid invalid-escape-sequence warnings in the regexes.
    encoding = int(re.findall(r"encoding=(\d+)_", weights)[0])
    k = int(re.findall(r"k=(\d+)_", weights)[0])
    loss = re.findall(r"loss=(.+?)_", weights)[0]
    NextSeqPred = float(re.findall(r"NextSeqPred=(.+?)_", weights)[0])
    PhenoPred = float(re.findall(r"PhenoPred=(.+?)_", weights)[0])
    RuleBasedPred = float(re.findall(r"RuleBasedPred=(.+?)_", weights)[0])

    (
        encode_to_vector,
        decoder,
        grouper,
        gmvaep,
        prior,
        posterior,
    ) = deepof.models.SEQ_2_SEQ_GMVAE(
        loss=loss,
        number_of_components=k,
        compile_model=True,
        encoding=encoding,
        next_sequence_prediction=NextSeqPred,
        phenotype_prediction=PhenoPred,
        rule_based_prediction=RuleBasedPred,
    ).build(data.shape)

    # os.path.join with a single argument was a no-op; pass the path directly.
    gmvaep.load_weights(weights)

    groups = grouper.predict(data)
    return groups
```
%% Cell type:code id: tags:
```
python
for group_id, weight_files in trained_weights_dict.items():
    print(group_id)
    for tw in tqdm(weight_files):
        # Only score models trained with this exact auxiliary-loss setup.
        wanted = (
            "NextSeqPred=0.15" in tw
            and "PhenoPred=0.0" in tw
            and "RuleBasedPred=0.0" in tw
        )
        if not wanted:
            continue
        try:
            groupings = load_model_and_get_groupings(data, tw)
        except ValueError:
            # Weights incompatible with the rebuilt model — skip this file.
            continue
        groupings_dict.setdefault(group_id, []).append(groupings)
```
%% Output
linear_25
linear_5
sigmoid_25
linear_20
linear_15
sigmoid_5
sigmoid_10
linear_10
sigmoid_20
sigmoid_15
%% Cell type:markdown id: tags:
### 4. Obtain ARI score and plot
%% Cell type:code id: tags:
```
python
# Hard-assign every sample to its argmax cluster; one column per trained model.
hard_groups_dict = {}
for key, soft_groupings in groupings_dict.items():
    columns = [np.argmax(g, axis=1)[:, np.newaxis] for g in soft_groupings]
    hard_groups_dict[key] = np.concatenate(columns, axis=1)
```
%% Cell type:code id: tags:
```
python
from collections import Counter

for key, hard_assignments in hard_groups_dict.items():
    print(key)
    # Frequency of each cluster label in the first model's column.
    print(Counter(list(hard_assignments[:, 0])))
```
%% Output
linear_25
Counter({3: 3850, 1: 2234, 4: 1231, 8: 1059, 10: 542, 13: 528, 11: 332, 9: 85, 14: 43, 12: 34, 2: 30, 6: 25, 0: 7})
linear_5
Counter({5: 2458, 1: 2404, 6: 2026, 10: 1207, 2: 745, 3: 461, 11: 368, 8: 114, 9: 64, 14: 63, 0: 36, 13: 31, 7: 14, 4: 9})
sigmoid_25
Counter({2: 4310, 1: 2665, 3: 1553, 12: 762, 7: 649, 5: 60, 0: 1})
linear_20
Counter({4: 4328, 7: 2692, 11: 861, 10: 819, 1: 432, 3: 346, 6: 343, 12: 112, 5: 37, 0: 10, 2: 10, 13: 6, 9: 4})
linear_15
Counter({10: 3242, 12: 2511, 8: 1811, 2: 1755, 7: 299, 3: 254, 1: 126, 9: 1, 0: 1})
sigmoid_5
Counter({11: 5779, 12: 1503, 13: 943, 9: 626, 5: 551, 3: 369, 1: 122, 2: 62, 0: 32, 8: 8, 4: 4, 14: 1})
sigmoid_10
Counter({1: 2079, 13: 2063, 0: 1967, 6: 854, 8: 824, 2: 703, 5: 495, 4: 378, 10: 267, 7: 175, 11: 80, 14: 57, 3: 43, 9: 15})
linear_10
Counter({2: 4258, 8: 2366, 14: 2324, 12: 286, 7: 247, 1: 162, 3: 128, 5: 66, 9: 59, 11: 53, 0: 35, 10: 9, 6: 6, 4: 1})
sigmoid_20
Counter({6: 3750, 1: 2603, 11: 1572, 14: 571, 9: 494, 8: 457, 2: 395, 12: 94, 3: 34, 5: 18, 13: 9, 7: 2, 4: 1})
sigmoid_15
Counter({6: 5090, 13: 2582, 0: 662, 9: 586, 2: 451, 12: 266, 1: 175, 14: 66, 3: 51, 10: 45, 8: 24, 5: 2})
%% Cell type:code id: tags:
```
python
def extended_ARI(groupings):
    """Adjusted Rand Index for every pair of columns in ``groupings``.

    Each column is one model's hard cluster assignment; returns the list of
    pairwise ARI scores over all column combinations.
    """
    ari_scores = []
    for col_a, col_b in combinations(range(groupings.shape[1]), 2):
        ari_scores.append(
            adjusted_rand_score(groupings[:, col_a], groupings[:, col_b])
        )
    return ari_scores
```
%% Cell type:code id: tags:
```
python
# Pairwise ARI per hyperparameter group.
ari_dict = {key: extended_ARI(groups) for key, groups in hard_groups_dict.items()}

# Human-readable labels for the plot; relies on dict insertion order matching.
simplified_keys = [
    "overlap_loss=0.1",
    "overlap_loss=0.2",
    "overlap_loss=0.5",
    "overlap_loss=0.75",
    "overlap_loss=1.0",
]
ari_dict = dict(zip(simplified_keys, ari_dict.values()))

# Keep only strictly positive ARI values.
ari_dict = {
    key: [score for score in scores if score > 0]
    for key, scores in ari_dict.items()
}
```
%% Cell type:code id: tags:
```
python
# Pad every ARI list to 45 entries so all boxplot groups have equal size.
# NOTE(review): each appended value is the mean of the list *including*
# previously appended padding, so successive pads drift toward the running
# mean — preserved as-is; confirm this is intended.
for key, scores in ari_dict.items():
    # len() works directly on lists; the original list(v) copies were redundant,
    # as was re-assigning the in-place-mutated list back into the dict.
    for _ in range(45 - len(scores)):
        scores.append(np.mean(scores))
```
%% Cell type:code id: tags:
```
python
# Visualise the ARI distribution per overlap-loss setting as horizontal boxplots.
plt.figure(figsize=(12, 8))

long_format = pd.DataFrame(ari_dict).melt()
sns.boxplot(data=long_format, x="value", y="variable")

# plt.xlim(0,1)
plt.savefig("deepof_variable_warmup.svg")
plt.show()
```
%% Output
%% Cell type:code id: tags:
```
python
```
...
...
supplementary_notebooks/deepof_model_evaluation.ipynb
View file @
22f291e6
This diff is collapsed.
Click to expand it.
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment