Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
1e83611f
Commit
1e83611f
authored
Sep 14, 2020
by
lucas_miranda
Browse files
Refactored functions in visuals.py
parent
0587b97a
Changes
1
Hide whitespace changes
Inline
Side-by-side
deepof/visuals.py
View file @
1e83611f
...
...
@@ -5,19 +5,20 @@ import numpy as np
import
pandas
as
pd
import
seaborn
as
sns
from
itertools
import
cycle
from
typing
import
List
,
Dict
# PLOTTING FUNCTIONS #
def
plot_speed
(
B
ehaviour_dict
,
T
reatments
)
:
def
plot_speed
(
b
ehaviour_
dict
:
dict
,
t
reatments
:
Dict
[
List
])
->
plt
.
figure
:
"""Plots a histogram with the speed of the specified mouse.
Treatments is expected to be a list of lists with mice keys per treatment"""
fig
,
[
ax1
,
ax2
]
=
plt
.
subplots
(
1
,
2
,
figsize
=
(
20
,
10
))
for
Treatment
,
Mice_list
in
T
reatments
.
items
():
hist
=
pd
.
concat
([
B
ehaviour_dict
[
mouse
]
for
mouse
in
Mice_list
])
for
Treatment
,
Mice_list
in
t
reatments
.
items
():
hist
=
pd
.
concat
([
b
ehaviour_dict
[
mouse
]
for
mouse
in
Mice_list
])
sns
.
kdeplot
(
hist
[
"bspeed"
],
shade
=
True
,
label
=
Treatment
,
ax
=
ax1
)
sns
.
kdeplot
(
hist
[
"wspeed"
],
shade
=
True
,
label
=
Treatment
,
ax
=
ax2
)
...
...
@@ -30,10 +31,13 @@ def plot_speed(Behaviour_dict, Treatments):
plt
.
show
()
def
plot_heatmap
(
dframe
,
bodyparts
,
xlim
,
ylim
,
save
=
False
):
def
plot_heatmap
(
dframe
:
pd
.
DataFrame
,
bodyparts
:
List
,
xlim
:
float
,
ylim
:
float
,
save
:
str
=
False
)
->
plt
.
figure
:
"""Returns a heatmap of the movement of a specific bodypart in the arena.
If more than one bodypart is passed, it returns one subplot for each"""
# noinspection PyTypeChecker
fig
,
ax
=
plt
.
subplots
(
1
,
len
(
bodyparts
),
sharex
=
True
,
sharey
=
True
)
for
i
,
bpart
in
enumerate
(
bodyparts
):
...
...
@@ -48,26 +52,24 @@ def plot_heatmap(dframe, bodyparts, xlim, ylim, save=False):
[
x
.
set_ylim
(
ylim
)
for
x
in
ax
]
[
x
.
set_title
(
bp
)
for
x
,
bp
in
zip
(
ax
,
bodyparts
)]
if
save
!=
False
:
if
save
:
plt
.
savefig
(
save
)
plt
.
show
()
def
model_comparison_plot
(
bic
,
m_bic
,
best_bic_gmm
,
n_components_range
,
cov_plot
,
save
,
cv_types
=
[
"spherical"
,
"tied"
,
"diag"
,
"full"
],
):
bic
:
list
,
m_bic
:
list
,
n_components_range
:
range
,
cov_plot
:
str
,
save
:
str
,
cv_types
:
tuple
=
(
"spherical"
,
"tied"
,
"diag"
,
"full"
),
)
->
plt
.
figure
:
"""Plots model comparison statistics over all tests"""
m_bic
=
np
.
array
(
m_bic
)
color_iter
=
cycle
([
"navy"
,
"turquoise"
,
"cornflowerblue"
,
"darkorange"
])
clf
=
best_bic_gmm
bars
=
[]
# Plot the BIC scores
...
...
@@ -93,6 +95,7 @@ def model_comparison_plot(
+
0.5
+
0.2
*
np
.
floor
(
m_bic
.
argmin
()
/
len
(
n_components_range
))
)
# noinspection PyArgumentList
spl
.
text
(
xpos
,
m_bic
.
min
()
*
0.97
+
0.1
*
m_bic
.
max
(),
"*"
,
fontsize
=
14
)
spl
.
legend
([
b
[
0
]
for
b
in
bars
],
cv_types
)
spl
.
set_ylabel
(
"BIC value"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment