Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
db3e8424
Commit
db3e8424
authored
Sep 17, 2020
by
lucas_miranda
Browse files
Added tests for deepof.visuals
parent
15b9eb4d
Changes
5
Hide whitespace changes
Inline
Side-by-side
deepof/preprocess.py
View file @
db3e8424
...
@@ -440,7 +440,7 @@ class coordinates:
...
@@ -440,7 +440,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
astype
(
'
timedelta64[s]
'
)
).
astype
(
"
timedelta64[s]
"
)
if
align
:
if
align
:
assert
(
assert
(
...
@@ -492,7 +492,7 @@ class coordinates:
...
@@ -492,7 +492,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
astype
(
'
timedelta64[s]
'
)
).
astype
(
"
timedelta64[s]
"
)
return
table_dict
(
tabs
,
typ
=
"dists"
)
return
table_dict
(
tabs
,
typ
=
"dists"
)
...
@@ -532,7 +532,7 @@ class coordinates:
...
@@ -532,7 +532,7 @@ class coordinates:
for
key
,
tab
in
tabs
.
items
():
for
key
,
tab
in
tabs
.
items
():
tabs
[
key
].
index
=
pd
.
timedelta_range
(
tabs
[
key
].
index
=
pd
.
timedelta_range
(
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
"00:00:00"
,
length
,
periods
=
tab
.
shape
[
0
]
+
1
,
closed
=
"left"
).
astype
(
'
timedelta64[s]
'
)
).
astype
(
"
timedelta64[s]
"
)
return
table_dict
(
tabs
,
typ
=
"angles"
)
return
table_dict
(
tabs
,
typ
=
"angles"
)
...
@@ -571,6 +571,7 @@ class coordinates:
...
@@ -571,6 +571,7 @@ class coordinates:
return
self
.
_arena
,
self
.
_arena_dims
,
self
.
_scales
return
self
.
_arena
,
self
.
_arena_dims
,
self
.
_scales
def
rule_based_annotation
(
self
):
def
rule_based_annotation
(
self
):
"""Annotates coordinates using a simple rule-based pipeline"""
pass
pass
...
@@ -634,10 +635,12 @@ class table_dict(dict):
...
@@ -634,10 +635,12 @@ class table_dict(dict):
else
[
0
,
self
.
_arena_dims
[
i
][
1
]]
else
[
0
,
self
.
_arena_dims
[
i
][
1
]]
)
)
plot_heatmap
(
heatmaps
=
plot_heatmap
(
list
(
self
.
values
())[
i
],
bodyparts
,
xlim
=
x_lim
,
ylim
=
y_lim
,
save
=
save
,
list
(
self
.
values
())[
i
],
bodyparts
,
xlim
=
x_lim
,
ylim
=
y_lim
,
save
=
save
,
)
)
return
heatmaps
def
get_training_set
(
self
,
test_videos
:
int
=
0
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
def
get_training_set
(
self
,
test_videos
:
int
=
0
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""Generates training and test sets as numpy.array objects for model training"""
"""Generates training and test sets as numpy.array objects for model training"""
...
...
deepof/visuals.py
View file @
db3e8424
...
@@ -13,42 +13,36 @@ import numpy as np
...
@@ -13,42 +13,36 @@ import numpy as np
import
pandas
as
pd
import
pandas
as
pd
import
seaborn
as
sns
import
seaborn
as
sns
from
itertools
import
cycle
from
itertools
import
cycle
from
typing
import
List
,
Dict
from
typing
import
List
# PLOTTING FUNCTIONS #
# PLOTTING FUNCTIONS #
def
plot_speed
(
behaviour_dict
:
Dict
[
str
,
pd
.
DataFrame
],
treatments
:
Dict
[
str
,
List
]
)
->
plt
.
figure
:
"""Plots a histogram with the speed of the specified mouse.
Treatments is expected to be a list of lists with mice keys per treatment"""
fig
,
[
ax1
,
ax2
]
=
plt
.
subplots
(
1
,
2
,
figsize
=
(
20
,
10
))
for
Treatment
,
Mice_list
in
treatments
.
items
():
hist
=
pd
.
concat
([
behaviour_dict
[
mouse
]
for
mouse
in
Mice_list
])
sns
.
kdeplot
(
hist
[
"bspeed"
],
shade
=
True
,
label
=
Treatment
,
ax
=
ax1
)
sns
.
kdeplot
(
hist
[
"wspeed"
],
shade
=
True
,
label
=
Treatment
,
ax
=
ax2
)
ax1
.
set_xlim
(
0
,
7
)
ax2
.
set_xlim
(
0
,
7
)
ax1
.
set_title
(
"Average speed density for black mouse"
)
ax2
.
set_title
(
"Average speed density for white mouse"
)
plt
.
xlabel
(
"Average speed"
)
plt
.
ylabel
(
"Density"
)
plt
.
show
()
def
plot_heatmap
(
def
plot_heatmap
(
dframe
:
pd
.
DataFrame
,
bodyparts
:
List
,
xlim
:
float
,
ylim
:
float
,
save
:
str
=
False
dframe
:
pd
.
DataFrame
,
bodyparts
:
List
,
xlim
:
tuple
,
ylim
:
tuple
,
save
:
str
=
False
,
dpi
:
int
=
200
,
)
->
plt
.
figure
:
)
->
plt
.
figure
:
"""Returns a heatmap of the movement of a specific bodypart in the arena.
"""Returns a heatmap of the movement of a specific bodypart in the arena.
If more than one bodypart is passed, it returns one subplot for each"""
If more than one bodypart is passed, it returns one subplot for each
Parameters:
- dframe (pandas.DataFrame): table_dict value with info to plot
- bodyparts (List): bodyparts to represent (at least 1)
- xlim (float): limits of the x-axis
- ylim (float): limits of the y-axis
- save (str): name of the file to which the figure should be saved
- dpi (int): dots per inch of the returned image
Returns:
- heatmaps (plt.figure): figure with the specified characteristics"""
# noinspection PyTypeChecker
# noinspection PyTypeChecker
fig
,
ax
=
plt
.
subplots
(
1
,
len
(
bodyparts
),
sharex
=
True
,
sharey
=
True
)
heatmaps
,
ax
=
plt
.
subplots
(
1
,
len
(
bodyparts
),
sharex
=
True
,
sharey
=
True
,
dpi
=
dpi
)
for
i
,
bpart
in
enumerate
(
bodyparts
):
for
i
,
bpart
in
enumerate
(
bodyparts
):
heatmap
=
dframe
[
bpart
]
heatmap
=
dframe
[
bpart
]
...
@@ -65,7 +59,7 @@ def plot_heatmap(
...
@@ -65,7 +59,7 @@ def plot_heatmap(
if
save
:
if
save
:
plt
.
savefig
(
save
)
plt
.
savefig
(
save
)
plt
.
show
()
return
heatmaps
def
model_comparison_plot
(
def
model_comparison_plot
(
...
@@ -73,17 +67,40 @@ def model_comparison_plot(
...
@@ -73,17 +67,40 @@ def model_comparison_plot(
m_bic
:
list
,
m_bic
:
list
,
n_components_range
:
range
,
n_components_range
:
range
,
cov_plot
:
str
,
cov_plot
:
str
,
save
:
str
,
save
:
str
=
False
,
cv_types
:
tuple
=
(
"spherical"
,
"tied"
,
"diag"
,
"full"
),
cv_types
:
tuple
=
(
"spherical"
,
"tied"
,
"diag"
,
"full"
),
dpi
:
int
=
200
,
)
->
plt
.
figure
:
)
->
plt
.
figure
:
"""Plots model comparison statistics over all tests"""
"""
Plots model comparison statistics for Gaussian Mixture Model analysis.
Similar to https://scikit-learn.org/stable/modules/mixture.html, it shows
an upper panel with BIC per number of components and covariance matrix type
in a bar plot, and a lower panel with box plots showing bootstrap runs of the
models corresponding to one of the covariance types.
Parameters:
- bic (list): list with BIC for all used models
- m_bic (list): list with minimum bic across cov matrices
for all used models
- n_components_range (range): range of components to evaluate
- cov_plot (str): covariance matrix to use in the lower panel
- save (str): name of the file to which the figure should be saved
- cv_types (tuple): tuple indicating which covariance matrix types
to use. All (spherical, tied, diag and full) used by default.
- dpi (int): dots per inch of the returned image
Returns:
- modelcomp (plt.figure): figure with all specified characteristics
"""
m_bic
=
np
.
array
(
m_bic
)
m_bic
=
np
.
array
(
m_bic
)
color_iter
=
cycle
([
"navy"
,
"turquoise"
,
"cornflowerblue"
,
"darkorange"
])
color_iter
=
cycle
([
"navy"
,
"turquoise"
,
"cornflowerblue"
,
"darkorange"
])
bars
=
[]
bars
=
[]
# Plot the BIC scores
# Plot the BIC scores
plt
.
figure
(
figsize
=
(
12
,
8
)
)
modelcomp
=
plt
.
figure
(
dpi
=
dpi
)
spl
=
plt
.
subplot
(
2
,
1
,
1
)
spl
=
plt
.
subplot
(
2
,
1
,
1
)
covplot
=
np
.
repeat
(
cv_types
,
len
(
m_bic
)
/
4
)
covplot
=
np
.
repeat
(
cv_types
,
len
(
m_bic
)
/
4
)
...
@@ -115,9 +132,7 @@ def model_comparison_plot(
...
@@ -115,9 +132,7 @@ def model_comparison_plot(
spl2
.
set_xlabel
(
"Number of components"
)
spl2
.
set_xlabel
(
"Number of components"
)
spl2
.
set_ylabel
(
"BIC value"
)
spl2
.
set_ylabel
(
"BIC value"
)
plt
.
tight_layout
()
if
save
:
if
save
:
plt
.
savefig
(
save
)
plt
.
savefig
(
save
)
plt
.
show
()
return
modelcomp
tests/test_model_utils.py
View file @
db3e8424
...
@@ -14,4 +14,4 @@ from hypothesis import strategies as st
...
@@ -14,4 +14,4 @@ from hypothesis import strategies as st
from
collections
import
defaultdict
from
collections
import
defaultdict
from
deepof.utils
import
*
from
deepof.utils
import
*
import
deepof.preprocess
import
deepof.preprocess
import
pytest
import
pytest
\ No newline at end of file
tests/test_preprocess.py
View file @
db3e8424
...
@@ -32,7 +32,7 @@ def test_project_init(table_type, arena_type):
...
@@ -32,7 +32,7 @@ def test_project_init(table_type, arena_type):
prun
=
deepof
.
preprocess
.
project
(
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
arena_type
,
arena
=
arena_type
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
angles
=
False
,
angles
=
False
,
video_format
=
".mp4"
,
video_format
=
".mp4"
,
table_format
=
table_type
,
table_format
=
table_type
,
...
@@ -41,7 +41,7 @@ def test_project_init(table_type, arena_type):
...
@@ -41,7 +41,7 @@ def test_project_init(table_type, arena_type):
prun
=
deepof
.
preprocess
.
project
(
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
arena_type
,
arena
=
arena_type
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
angles
=
False
,
angles
=
False
,
video_format
=
".mp4"
,
video_format
=
".mp4"
,
table_format
=
table_type
,
table_format
=
table_type
,
...
@@ -72,7 +72,7 @@ def test_get_distances(nodes, ego):
...
@@ -72,7 +72,7 @@ def test_get_distances(nodes, ego):
prun
=
deepof
.
preprocess
.
project
(
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
angles
=
False
,
angles
=
False
,
video_format
=
".mp4"
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
table_format
=
".h5"
,
...
@@ -98,7 +98,7 @@ def test_get_angles(nodes, ego):
...
@@ -98,7 +98,7 @@ def test_get_angles(nodes, ego):
prun
=
deepof
.
preprocess
.
project
(
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
video_format
=
".mp4"
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
table_format
=
".h5"
,
distances
=
nodes
,
distances
=
nodes
,
...
@@ -123,7 +123,7 @@ def test_run(nodes, ego):
...
@@ -123,7 +123,7 @@ def test_run(nodes, ego):
prun
=
deepof
.
preprocess
.
project
(
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
video_format
=
".mp4"
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
table_format
=
".h5"
,
distances
=
nodes
,
distances
=
nodes
,
...
@@ -147,7 +147,7 @@ def test_get_table_dicts(nodes, ego, sampler):
...
@@ -147,7 +147,7 @@ def test_get_table_dicts(nodes, ego, sampler):
prun
=
deepof
.
preprocess
.
project
(
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena
=
"circular"
,
arena_dims
=
[
380
],
arena_dims
=
tuple
(
[
380
]
)
,
video_format
=
".mp4"
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
table_format
=
".h5"
,
distances
=
nodes
,
distances
=
nodes
,
...
...
tests/test_visuals.py
View file @
db3e8424
...
@@ -11,7 +11,63 @@ Testing module for deepof.visuals
...
@@ -11,7 +11,63 @@ Testing module for deepof.visuals
from
hypothesis
import
given
from
hypothesis
import
given
from
hypothesis
import
settings
from
hypothesis
import
settings
from
hypothesis
import
strategies
as
st
from
hypothesis
import
strategies
as
st
from
collections
import
defaultdict
from
deepof.utils
import
*
from
deepof.utils
import
*
import
deepof.preprocess
import
deepof.preprocess
import
pytest
import
deepof.visuals
\ No newline at end of file
import
matplotlib.figure
def
test_plot_heatmap
():
prun
=
(
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
tuple
([
380
]),
angles
=
False
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
)
.
run
()
.
get_coords
()
)
assert
(
type
(
deepof
.
visuals
.
plot_heatmap
(
prun
[
"test"
],
[
"Center"
],
tuple
([
-
100
,
100
]),
tuple
([
-
100
,
100
]),
dpi
=
200
,
)
)
==
matplotlib
.
figure
.
Figure
)
def
test_model_comparison_plot
():
prun
=
(
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
arena
=
"circular"
,
arena_dims
=
tuple
([
380
]),
angles
=
False
,
video_format
=
".mp4"
,
table_format
=
".h5"
,
)
.
run
()
.
get_coords
()
)
gmm_run
=
gmm_model_selection
(
prun
[
"test"
],
n_components_range
=
range
(
1
,
3
),
n_runs
=
1
,
part_size
=
100
)
assert
(
type
(
deepof
.
visuals
.
model_comparison_plot
(
gmm_run
[
0
],
gmm_run
[
1
],
range
(
1
,
3
),
cov_plot
=
"full"
)
)
==
matplotlib
.
figure
.
Figure
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment