Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
Lucas Miranda
deepOF
Commits
db3e8424
Commit
db3e8424
authored
Sep 17, 2020
by
lucas_miranda
Browse files
Added tests for deepof.visuals
parent
15b9eb4d
Changes
5
Hide whitespace changes
Inline
Side-by-side
deepof/preprocess.py
View file @
db3e8424
...
...
@@ -440,7 +440,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
astype
(
'
timedelta64[s]
'
)
).
astype
(
"
timedelta64[s]
"
)
if
align
:
assert
(
...
...
@@ -492,7 +492,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
astype
(
'
timedelta64[s]
'
)
).
astype
(
"
timedelta64[s]
"
)
return
table_dict
(
tabs
,
typ
=
"dists"
)
...
...
@@ -532,7 +532,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
astype
(
'
timedelta64[s]
'
)
).
astype
(
"
timedelta64[s]
"
)
return
table_dict
(
tabs
,
typ
=
"angles"
)
...
...
@@ -571,6 +571,7 @@ class coordinates:
return
self
.
_arena
,
self
.
_arena_dims
,
self
.
_scales
def
rule_based_annotation
(
self
):
"""Annotates coordinates using a simple rule-based pipeline"""
pass
...
...
@@ -634,10 +635,12 @@ class table_dict(dict):
else
[
0
,
self
.
_arena_dims
[
i
][
1
]]
)
plot_heatmap
(
heatmaps
=
plot_heatmap
(
list
(
self
.
values
())[
i
],
bodyparts
,
xlim
=
x_lim
,
ylim
=
y_lim
,
save
=
save
,
)
return
heatmaps
def
get_training_set
(
self
,
test_videos
:
int
=
0
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""Generates training and test sets as numpy.array objects for model training"""
...
...
deepof/visuals.py
View file @
db3e8424
...
...
@@ -13,42 +13,36 @@ import numpy as np
import
pandas
as
pd
import
seaborn
as
sns
from
itertools
import
cycle
from
typing
import
List
,
Dict
from
typing
import
List
# PLOTTING FUNCTIONS #
def
plot_speed
(
behaviour_dict
:
Dict
[
str
,
pd
.
DataFrame
],
treatments
:
Dict
[
str
,
List
]
)
->
plt
.
figure
:
"""Plots a histogram with the speed of the specified mouse.
Treatments is expected to be a list of lists with mice keys per treatment"""
fig
,
[
ax1
,
ax2
]
=
plt
.
subplots
(
1
,
2
,
figsize
=
(
20
,
10
))
for
Treatment
,
Mice_list
in
treatments
.
items
():
hist
=
pd
.
concat
([
behaviour_dict
[
mouse
]
for
mouse
in
Mice_list
])
sns
.
kdeplot
(
hist
[
"bspeed"
],
shade
=
True
,
label
=
Treatment
,
ax
=
ax1
)
sns
.
kdeplot
(
hist
[
"wspeed"
],
shade
=
True
,
label
=
Treatment
,
ax
=
ax2
)
ax1
.
set_xlim
(
0
,
7
)
ax2
.
set_xlim
(
0
,
7
)
ax1
.
set_title
(
"Average speed density for black mouse"
)
ax2
.
set_title
(
"Average speed density for white mouse"
)
plt
.
xlabel
(
"Average speed"
)
plt
.
ylabel
(
"Density"
)
plt
.
show
()
def
plot_heatmap
(
dframe
:
pd
.
DataFrame
,
bodyparts
:
List
,
xlim
:
float
,
ylim
:
float
,
save
:
str
=
False
dframe
:
pd
.
DataFrame
,
bodyparts
:
List
,
xlim
:
tuple
,
ylim
:
tuple
,
save
:
str
=
False
,
dpi
:
int
=
200
,
)
->
plt
.
figure
:
"""Returns a heatmap of the movement of a specific bodypart in the arena.
If more than one bodypart is passed, it returns one subplot for each"""
If more than one bodypart is passed, it returns one subplot for each
Parameters:
- dframe (pandas.DataFrame): table_dict value with info to plot
- bodyparts (List): bodyparts to represent (at least 1)
- xlim (float): limits of the x-axis
- ylim (float): limits of the y-axis
- save (str): name of the file to which the figure should be saved
- dpi (int): dots per inch of the returned image
Returns:
- heatmaps (plt.figure): figure with the specified characteristics"""
# noinspection PyTypeChecker
fig
,
ax
=
plt
.
subplots
(
1
,
len
(
bodyparts
),
sharex
=
True
,
sharey
=
True
)
heatmaps
,
ax
=
plt
.
subplots
(
1
,
len
(
bodyparts
),
sharex
=
True
,
sharey
=
True
,
dpi
=
dpi
)
for
i
,
bpart
in
enumerate
(
bodyparts
):
heatmap
=
dframe
[
bpart
]
...
...
@@ -65,7 +59,7 @@ def plot_heatmap(
if
save
:
plt
.
savefig
(
save
)
plt
.
show
()
return
heatmaps
def
model_comparison_plot
(
...
...
@@ -73,17 +67,40 @@ def model_comparison_plot(
m_bic
:
list
,
n_components_range
:
range
,
cov_plot
:
str
,
save
:
str
,
save
:
str
=
False
,
cv_types
:
tuple
=
(
"spherical"
,
"tied"
,
"diag"
,
"full"
),
dpi
:
int
=
200
,
)
->
plt
.
figure
:
"""Plots model comparison statistics over all tests"""
"""
Plots model comparison statistics for Gaussian Mixture Model analysis.
Similar to https://scikit-learn.org/stable/modules/mixture.html, it shows
an upper panel with BIC per number of components and covariance matrix type
in a bar plot, and a lower panel with box plots showing bootstrap runs of the
models corresponding to one of the covariance types.
Parameters:
- bic (list): list with BIC for all used models
- m_bic (list): list with minimum bic across cov matrices
for all used models
- n_components_range (range): range of components to evaluate
- cov_plot (str): covariance matrix to use in the lower panel
- save (str): name of the file to which the figure should be saved
- cv_types (tuple): tuple indicating which covariance matrix types
to use. All (spherical, tied, diag and full) used by default.
- dpi (int): dots per inch of the returned image
Returns:
- modelcomp (plt.figure): figure with all specified characteristics
"""
m_bic
=
np
.
array
(
m_bic
)
color_iter
=
cycle
([
"navy"
,
"turquoise"
,
"cornflowerblue"
,
"darkorange"
])
bars
=
[]
# Plot the BIC scores
plt
.
figure
(
figsize
=
(
12
,
8
)
)
modelcomp
=
plt
.
figure
(
dpi
=
dpi
)
spl
=
plt
.
subplot
(
2
,
1
,
1
)
covplot
=
np
.
repeat
(
cv_types
,
len
(
m_bic
)
/
4
)
...
...
@@ -115,9 +132,7 @@ def model_comparison_plot(
spl2
.
set_xlabel
(
"Number of components"
)
spl2
.
set_ylabel
(
"BIC value"
)
plt
.
tight_layout
()
if
save
:
plt
.
savefig
(
save
)
plt
.
show
()
return
modelcomp
tests/test_model_utils.py
View file @
db3e8424
...
...
@@ -14,4 +14,4 @@ from hypothesis import strategies as st
from
collections
import
defaultdict
from
deepof.utils
import
*
import
deepof.preprocess
import
pytest
\ No newline at end of file
import
pytest
tests/test_preprocess.py
View file @
db3e8424
...
...
@@ -32,7 +32,7 @@ def test_project_init(table_type, arena_type):
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
arena_type
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
angles
=
False
,
video_format
=
".mp4"
,
table_format
=
table_type
,
...
...
@@ -41,7 +41,7 @@ def test_project_init(table_type, arena_type):
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
arena_type
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
angles
=
False
,
video_format
=
".mp4"
,
table_format
=
table_type
,
...
...
@@ -72,7 +72,7 @@ def test_get_distances(nodes, ego):
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
angles
=
False
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
...
...
@@ -98,7 +98,7 @@ def test_get_angles(nodes, ego):
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
distances
=
nodes
,
...
...
@@ -123,7 +123,7 @@ def test_run(nodes, ego):
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
distances
=
nodes
,
...
...
@@ -147,7 +147,7 @@ def test_get_table_dicts(nodes, ego, sampler):
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
distances
=
nodes
,
...
...
tests/test_visuals.py
View file @
db3e8424
...
...
@@ -11,7 +11,63 @@ Testing module for deepof.visuals
from
hypothesis
import
given
from
hypothesis
import
settings
from
hypothesis
import
strategies
as
st
from
collections
import
defaultdict
from
deepof.utils
import
*
import
deepof.preprocess
import
pytest
\ No newline at end of file
import
deepof.visuals
import
matplotlib.figure
def
test_plot_heatmap
():
prun
=
(
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
tuple
([
380
]),
angles
=
False
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
)
.
run
()
.
get_coords
()
)
assert
(
type
(
deepof
.
visuals
.
plot_heatmap
(
prun
[
"test"
],
[
"Center"
],
tuple
([
-
100
,
100
]),
tuple
([
-
100
,
100
]),
dpi
=
200
,
)
)
==
matplotlib
.
figure
.
Figure
)
def
test_model_comparison_plot
():
prun
=
(
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
tuple
([
380
]),
angles
=
False
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
)
.
run
()
.
get_coords
()
)
gmm_run
=
gmm_model_selection
(
prun
[
"test"
],
n_components_range
=
range
(
1
,
3
),
n_runs
=
1
,
part_size
=
100
)
assert
(
type
(
deepof
.
visuals
.
model_comparison_plot
(
gmm_run
[
0
],
gmm_run
[
1
],
range
(
1
,
3
),
cov_plot
=
"full"
)
)
==
matplotlib
.
figure
.
Figure
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment