Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
1c4b1a44
Commit
1c4b1a44
authored
Sep 15, 2020
by
lucas_miranda
Browse files
Added tests for rule_based_tagging
parent
a93b3a26
Changes
2
Hide whitespace changes
Inline
Side-by-side
deepof/utils.py
View file @
1c4b1a44
...
...
@@ -413,6 +413,9 @@ def recognize_arena(
fnum
+=
1
cap
.
release
()
cv2
.
destroyAllWindows
()
return
arena
,
h
,
w
...
...
@@ -481,7 +484,7 @@ def climb_wall(
def
rolling_speed
(
dframe
:
pd
.
DatetimeIndex
,
window
:
int
=
10
,
rounds
:
int
=
10
,
deriv
:
int
=
1
dframe
:
pd
.
DatetimeIndex
,
window
:
int
=
5
,
rounds
:
int
=
10
,
deriv
:
int
=
1
)
->
pd
.
DataFrame
:
"""Returns the average speed over n frames in pixels per frame
...
...
@@ -894,12 +897,11 @@ def rule_based_tagging(
videos
:
List
,
coordinates
:
Coordinates
,
vid_index
:
int
,
arena_abs
:
int
,
animal_ids
:
List
=
None
,
show
:
bool
=
False
,
save
:
bool
=
False
,
fps
:
float
=
25.0
,
speed_pause
:
int
=
5
0
,
speed_pause
:
int
=
1
0
,
frame_limit
:
float
=
np
.
inf
,
recog_limit
:
int
=
1
,
path
:
str
=
os
.
path
.
join
(
"./"
),
...
...
@@ -911,7 +913,7 @@ def rule_based_tagging(
follow_tol
:
int
=
20
,
huddle_forward
:
int
=
15
,
huddle_spine
:
int
=
10
,
huddle_speed
:
int
=
5
,
huddle_speed
:
int
=
1
,
)
->
pd
.
DataFrame
:
"""Outputs a dataframe with the motives registered per frame."""
...
...
@@ -919,6 +921,7 @@ def rule_based_tagging(
coords
=
coordinates
.
get_coords
()[
vid_name
]
speeds
=
coordinates
.
get_coords
(
speed
=
1
)[
vid_name
]
arena_abs
=
coordinates
.
get_arenas
[
1
][
0
]
arena
,
h
,
w
=
recognize_arena
(
videos
,
vid_index
,
path
,
recog_limit
,
arena_type
)
# Dictionary with motives per frame
...
...
@@ -997,7 +1000,7 @@ def rule_based_tagging(
pd
.
Series
(
(
spatial
.
distance
.
cdist
(
np
.
array
(
coords
[
_id
+
"_Nose"
]),
np
.
array
([
arena
[:
2
]
])
np
.
array
(
coords
[
_id
+
"_Nose"
]),
np
.
zeros
([
1
,
2
])
)
>
(
w
/
200
+
arena
[
2
])
).
reshape
(
coords
.
shape
[
0
]),
...
...
@@ -1005,208 +1008,212 @@ def rule_based_tagging(
).
astype
(
bool
)
)
tag_dict
[
_id
+
"_speed"
]
=
speeds
[
_id
+
"_speed"
]
tag_dict
[
_id
+
"_huddle"
]
=
smooth_boolean_array
(
huddle
(
coords
,
speeds
,
huddle_forward
,
huddle_spine
,
huddle_speed
)
)
else
:
tag_dict
[
"climbing"
]
=
smooth_boolean_array
(
pd
.
Series
(
(
spatial
.
distance
.
cdist
(
np
.
array
(
coords
[
"Nose"
]),
np
.
array
([
arena
[:
2
]])
)
spatial
.
distance
.
cdist
(
np
.
array
(
coords
[
"Nose"
]),
np
.
zeros
([
1
,
2
]))
>
(
w
/
200
+
arena
[
2
])
).
reshape
(
coords
.
shape
[
0
]),
index
=
coords
.
index
,
).
astype
(
bool
)
)
tag_dict
[
"speed"
]
=
speeds
[
"Center"
]
tag_dict
[
"huddle"
]
=
smooth_boolean_array
(
huddle
(
coords
,
speeds
,
huddle_forward
,
huddle_spine
,
huddle_speed
)
)
if
classifiers
and
"huddle"
in
classifiers
:
mouse_X
=
{
_id
:
np
.
array
(
coords
[
vid_name
][
[
j
for
j
in
coords
[
vid_name
].
keys
()
if
(
len
(
j
)
==
2
and
_id
in
j
[
0
]
and
_id
in
j
[
1
])
]
]
)
for
_id
in
animal_ids
}
for
_id
in
animal_ids
:
tag_dict
[
_id
+
"_huddle"
]
=
smooth_boolean_array
(
classifiers
[
"huddle"
].
predict
(
mouse_X
[
_id
])
)
else
:
try
:
for
_id
in
animal_ids
:
tag_dict
[
_id
+
"_huddle"
]
=
smooth_boolean_array
(
huddle
(
coords
,
speeds
,
huddle_forward
,
huddle_spine
,
huddle_speed
)
if
any
([
show
,
save
]):
cap
=
cv2
.
VideoCapture
(
os
.
path
.
join
(
path
,
videos
[
vid_index
]))
# Keep track of the frame number, to align with the tracking data
fnum
=
0
writer
=
None
frame_speeds
=
{
_id
:
-
np
.
inf
for
_id
in
animal_ids
}
if
animal_ids
else
-
np
.
inf
# Loop over the frames in the video
pbar
=
tqdm
(
total
=
min
(
coords
.
shape
[
0
]
-
recog_limit
,
frame_limit
))
while
cap
.
isOpened
()
and
fnum
<
frame_limit
:
ret
,
frame
=
cap
.
read
()
# if frame is read correctly ret is True
if
not
ret
:
print
(
"Can't receive frame (stream end?). Exiting ..."
)
break
font
=
cv2
.
FONT_HERSHEY_COMPLEX_SMALL
# Label positions
downleft
=
(
int
(
w
*
0.3
/
10
),
int
(
h
/
1.05
))
downright
=
(
int
(
w
*
6.5
/
10
),
int
(
h
/
1.05
))
upleft
=
(
int
(
w
*
0.3
/
10
),
int
(
h
/
20
))
upright
=
(
int
(
w
*
6.3
/
10
),
int
(
h
/
20
))
# Capture speeds
try
:
if
list
(
frame_speeds
.
values
())[
0
]
==
-
np
.
inf
or
fnum
%
speed_pause
==
0
:
for
_id
in
animal_ids
:
frame_speeds
[
_id
]
=
speeds
[
_id
+
"_Center"
][
fnum
]
except
AttributeError
:
if
frame_speeds
==
-
np
.
inf
or
fnum
%
speed_pause
==
0
:
frame_speeds
=
speeds
[
"Center"
][
fnum
]
# Display all annotations in the output video
if
animal_ids
:
if
tag_dict
[
"nose2nose"
][
fnum
]
and
not
tag_dict
[
"sidebyside"
][
fnum
]:
cv2
.
putText
(
frame
,
"Nose-Nose"
,
(
downleft
if
frame_speeds
[
animal_ids
[
0
]]
>
frame_speeds
[
animal_ids
[
1
]]
else
downright
),
font
,
1
,
(
255
,
255
,
255
),
2
,
)
if
(
tag_dict
[
animal_ids
[
0
]
+
"_nose2tail"
][
fnum
]
and
not
tag_dict
[
"sidereside"
][
fnum
]
):
cv2
.
putText
(
frame
,
"Nose-Tail"
,
downleft
,
font
,
1
,
(
255
,
255
,
255
),
2
)
if
(
tag_dict
[
animal_ids
[
1
]
+
"_nose2tail"
][
fnum
]
and
not
tag_dict
[
"sidereside"
][
fnum
]
):
cv2
.
putText
(
frame
,
"Nose-Tail"
,
downright
,
font
,
1
,
(
255
,
255
,
255
),
2
)
if
tag_dict
[
"sidebyside"
][
fnum
]:
cv2
.
putText
(
frame
,
"Side-side"
,
(
downleft
if
frame_speeds
[
animal_ids
[
0
]]
>
frame_speeds
[
animal_ids
[
1
]]
else
downright
),
font
,
1
,
(
255
,
255
,
255
),
2
,
)
if
tag_dict
[
"sidereside"
][
fnum
]:
cv2
.
putText
(
frame
,
"Side-Rside"
,
(
downleft
if
frame_speeds
[
animal_ids
[
0
]]
>
frame_speeds
[
animal_ids
[
1
]]
else
downright
),
font
,
1
,
(
255
,
255
,
255
),
2
,
)
for
_id
,
down_pos
,
up_pos
in
zip
(
animal_ids
,
[
downleft
,
downright
],
[
upleft
,
upright
]
):
if
tag_dict
[
_id
+
"_climbing"
][
fnum
]:
cv2
.
putText
(
frame
,
"Climbing"
,
down_pos
,
font
,
1
,
(
255
,
255
,
255
),
2
)
if
(
tag_dict
[
_id
+
"_huddle"
][
fnum
]
and
not
tag_dict
[
_id
+
"_climbing"
][
fnum
]
):
cv2
.
putText
(
frame
,
"Huddling"
,
down_pos
,
font
,
1
,
(
255
,
255
,
255
),
2
)
if
(
tag_dict
[
_id
+
"_following"
][
fnum
]
and
not
tag_dict
[
_id
+
"_climbing"
][
fnum
]
):
cv2
.
putText
(
frame
,
"*f"
,
(
int
(
w
*
0.3
/
10
),
int
(
h
/
10
)),
font
,
1
,
(
(
150
,
150
,
255
)
if
frame_speeds
[
animal_ids
[
0
]]
>
frame_speeds
[
animal_ids
[
1
]]
else
(
150
,
255
,
150
)
),
2
,
)
cv2
.
putText
(
frame
,
_id
+
": "
+
str
(
np
.
round
(
frame_speeds
[
_id
],
2
))
+
" mmpf"
,
(
up_pos
[
0
]
-
20
,
up_pos
[
1
]),
font
,
1
,
(
(
150
,
150
,
255
)
if
frame_speeds
[
_id
]
==
max
(
list
(
frame_speeds
.
values
()))
else
(
150
,
255
,
150
)
),
2
,
)
else
:
if
tag_dict
[
"climbing"
][
fnum
]:
cv2
.
putText
(
frame
,
"Climbing"
,
downleft
,
font
,
1
,
(
255
,
255
,
255
),
2
)
if
tag_dict
[
"huddle"
][
fnum
]
and
not
tag_dict
[
"climbing"
][
fnum
]:
cv2
.
putText
(
frame
,
"huddle"
,
downleft
,
font
,
1
,
(
255
,
255
,
255
),
2
)
cv2
.
putText
(
frame
,
str
(
np
.
round
(
frame_speeds
,
2
))
+
" mmpf"
,
upleft
,
font
,
1
,
(
(
150
,
150
,
255
)
if
huddle_speed
>
frame_speeds
else
(
150
,
255
,
150
)
),
2
,
)
except
TypeError
:
tag_dict
[
"huddle"
]
=
smooth_boolean_array
(
huddle
(
coords
,
speeds
,
huddle_forward
,
huddle_spine
,
huddle_speed
)
)
# if any([show, save]):
# cap = cv2.VideoCapture(path + videos[vid_index])
#
# # Keep track of the frame number, to align with the tracking data
# fnum = 0
# if save:
# writer = None
#
# # Loop over the frames in the video
# pbar = tqdm(total=min(coords.shape[0] - recog_limit, frame_limit))
# while cap.isOpened() and fnum < frame_limit:
#
# ret, frame = cap.read()
# # if frame is read correctly ret is True
# if not ret:
# print("Can't receive frame (stream end?). Exiting ...")
# break
#
# font = cv2.FONT_HERSHEY_COMPLEX_SMALL
#
# if like_qc_dict[vid_name][fnum]:
#
# # Extract positions
# pos_dict = {
# i: np.array([coords[i]["x"][fnum], coords[i]["y"][fnum]])
# for i in coords.columns.levels[0]
# if i != "Like_QC"
# }
#
# if h is None and w is None:
# h, w = frame.shape[0], frame.shape[1]
#
# # Label positions
# downleft = (int(w * 0.3 / 10), int(h / 1.05))
# downright = (int(w * 6.5 / 10), int(h / 1.05))
# upleft = (int(w * 0.3 / 10), int(h / 20))
# upright = (int(w * 6.3 / 10), int(h / 20))
#
# # Display all annotations in the output video
# if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
# cv2.putText(
# frame,
# "Nose-Nose",
# (downleft if bspeed > wspeed else downright),
# font,
# 1,
# (255, 255, 255),
# 2,
# )
# if tag_dict["bnose2tail"][fnum] and not tag_dict["sidereside"][fnum]:
# cv2.putText(
# frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2
# )
# if tag_dict["wnose2tail"][fnum] and not tag_dict["sidereside"][fnum]:
# cv2.putText(
# frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2
# )
# if tag_dict["sidebyside"][fnum]:
# cv2.putText(
# frame,
# "Side-side",
# (downleft if bspeed > wspeed else downright),
# font,
# 1,
# (255, 255, 255),
# 2,
# )
# if tag_dict["sidereside"][fnum]:
# cv2.putText(
# frame,
# "Side-Rside",
# (downleft if bspeed > wspeed else downright),
# font,
# 1,
# (255, 255, 255),
# 2,
# )
# if tag_dict["bclimbwall"][fnum]:
# cv2.putText(
# frame, "Climbing", downleft, font, 1, (255, 255, 255), 2
# )
# if tag_dict["wclimbwall"][fnum]:
# cv2.putText(
# frame, "Climbing", downright, font, 1, (255, 255, 255), 2
# )
# if tag_dict["bhuddle"][fnum] and not tag_dict["bclimbwall"][fnum]:
# cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2)
# if tag_dict["whuddle"][fnum] and not tag_dict["wclimbwall"][fnum]:
# cv2.putText(frame, "huddle", downright, font, 1, (255, 255, 255), 2)
# if tag_dict["bfollowing"][fnum] and not tag_dict["bclimbwall"][fnum]:
# cv2.putText(
# frame,
# "*f",
# (int(w * 0.3 / 10), int(h / 10)),
# font,
# 1,
# ((150, 150, 255) if wspeed > bspeed else (150, 255, 150)),
# 2,
# )
# if tag_dict["wfollowing"][fnum] and not tag_dict["wclimbwall"][fnum]:
# cv2.putText(
# frame,
# "*f",
# (int(w * 6.3 / 10), int(h / 10)),
# font,
# 1,
# ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)),
# 2,
# )
#
# if (bspeed == None and wspeed == None) or fnum % speed_pause == 0:
# bspeed = tag_dict["bspeed"][fnum]
# wspeed = tag_dict["wspeed"][fnum]
#
# cv2.putText(
# frame,
# "W: " + str(np.round(wspeed, 2)) + " mmpf",
# (upright[0] - 20, upright[1]),
# font,
# 1,
# ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)),
# 2,
# )
# cv2.putText(
# frame,
# "B: " + str(np.round(bspeed, 2)) + " mmpf",
# upleft,
# font,
# 1,
# ((150, 150, 255) if bspeed < wspeed else (150, 255, 150)),
# 2,
# )
#
# if show:
# cv2.imshow("frame", frame)
#
# if save:
#
# if writer is None:
# # Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file.
# # Define the FPS. Also frame size is passed.
# writer = cv2.VideoWriter()
# writer.open(
# re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
# cv2.VideoWriter_fourcc(*"MJPG"),
# fps,
# (frame.shape[1], frame.shape[0]),
# True,
# )
# writer.write(frame)
#
# if cv2.waitKey(1) == ord("q"):
# break
#
# pbar.update(1)
# fnum += 1
#
# cap.release()
# cv2.destroyAllWindows()
if
show
:
cv2
.
imshow
(
"frame"
,
frame
)
if
cv2
.
waitKey
(
1
)
==
ord
(
"q"
):
break
if
save
:
if
writer
is
None
:
# Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file.
# Define the FPS. Also frame size is passed.
writer
=
cv2
.
VideoWriter
()
writer
.
open
(
re
.
findall
(
"(.*?)_"
,
tracks
[
vid_index
])[
0
]
+
"_tagged.avi"
,
cv2
.
VideoWriter_fourcc
(
*
"MJPG"
),
fps
,
(
frame
.
shape
[
1
],
frame
.
shape
[
0
]),
True
,
)
writer
.
write
(
frame
)
pbar
.
update
(
1
)
fnum
+=
1
cap
.
release
()
cv2
.
destroyAllWindows
()
tag_df
=
pd
.
DataFrame
(
tag_dict
)
...
...
@@ -1218,7 +1225,3 @@ def rule_based_tagging(
# - Add digging to rule_based_tagging
# - Add center to rule_based_tagging
# - Check for features requested by Joeri
# - Check speed. Avoid recomputing unnecessarily
# - Pass thresholds as parameters of the function. Provide defaults (we should tune them in the future)
# - Check if attributes I'm asking for (eg arena) are already stored in Table_dict metadata
tests/test_utils.py
View file @
1c4b1a44
...
...
@@ -754,9 +754,7 @@ def test_cluster_transition_matrix(sampler, autocorrelation, return_graph):
assert
type
(
trans
)
==
np
.
ndarray
@
settings
(
deadline
=
None
)
@
given
(
sampler
=
st
.
data
())
def
test_rule_based_tagging
(
sampler
):
def
test_rule_based_tagging
():
prun
=
deepof
.
preprocess
.
project
(
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
),
...
...
@@ -772,8 +770,9 @@ def test_rule_based_tagging(sampler):
[
"test_video_circular_arena.mp4"
],
prun
,
vid_index
=
0
,
arena_abs
=
380
,
path
=
os
.
path
.
join
(
"."
,
"tests"
,
"test_examples"
,
"Videos"
),
save
=
True
,
frame_limit
=
100
,
)
assert
type
(
hardcoded_tags
)
==
pd
.
DataFrame
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment