Commit 1c4b1a44 authored by lucas_miranda

Added tests for rule_based_tagging

parent a93b3a26
@@ -413,6 +413,9 @@ def recognize_arena(
fnum += 1
cap.release()
cv2.destroyAllWindows()
return arena, h, w
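
The call site added later in this commit unpacks these return values as follows; the (x, y, r) parametrisation of `arena` is an assumption inferred from the `arena[:2]` and `arena[2]` usages further down the diff:

```python
# Mirrors the call added in rule_based_tagging below (not new API):
arena, h, w = recognize_arena(videos, vid_index, path, recog_limit, arena_type)
# Assumption: for a circular arena, arena = (centre_x, centre_y, radius),
# given that arena[:2] and arena[2] are used as centre and radius below
centre_x, centre_y, radius = arena
```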
@@ -481,7 +484,7 @@ def climb_wall(
def rolling_speed(
- dframe: pd.DatetimeIndex, window: int = 10, rounds: int = 10, deriv: int = 1
+ dframe: pd.DatetimeIndex, window: int = 5, rounds: int = 10, deriv: int = 1
) -> pd.DataFrame:
"""Returns the average speed over n frames in pixels per frame
@@ -894,12 +897,11 @@ def rule_based_tagging(
videos: List,
coordinates: Coordinates,
vid_index: int,
- arena_abs: int,
animal_ids: List = None,
show: bool = False,
save: bool = False,
fps: float = 25.0,
- speed_pause: int = 50,
+ speed_pause: int = 10,
frame_limit: float = np.inf,
recog_limit: int = 1,
path: str = os.path.join("./"),
@@ -911,7 +913,7 @@ def rule_based_tagging(
follow_tol: int = 20,
huddle_forward: int = 15,
huddle_spine: int = 10,
- huddle_speed: int = 5,
+ huddle_speed: int = 1,
) -> pd.DataFrame:
"""Outputs a dataframe with the motives registered per frame."""
@@ -919,6 +921,7 @@ def rule_based_tagging(
coords = coordinates.get_coords()[vid_name]
speeds = coordinates.get_coords(speed=1)[vid_name]
+ arena_abs = coordinates.get_arenas[1][0]
arena, h, w = recognize_arena(videos, vid_index, path, recog_limit, arena_type)
# Dictionary with motives per frame
@@ -997,7 +1000,7 @@ def rule_based_tagging(
pd.Series(
(
spatial.distance.cdist(
- np.array(coords[_id + "_Nose"]), np.array([arena[:2]])
+ np.array(coords[_id + "_Nose"]), np.zeros([1,2])
)
> (w / 200 + arena[2])
).reshape(coords.shape[0]),
@@ -1005,208 +1008,212 @@ def rule_based_tagging(
).astype(bool)
)
tag_dict[_id + "_speed"] = speeds[_id + "_speed"]
tag_dict[_id + "_huddle"] = smooth_boolean_array(
huddle(coords, speeds, huddle_forward, huddle_spine, huddle_speed)
)
else:
tag_dict["climbing"] = smooth_boolean_array(
pd.Series(
(
- spatial.distance.cdist(
- np.array(coords["Nose"]), np.array([arena[:2]])
- )
+ spatial.distance.cdist(np.array(coords["Nose"]), np.zeros([1,2]))
> (w / 200 + arena[2])
).reshape(coords.shape[0]),
index=coords.index,
).astype(bool)
)
tag_dict["speed"] = speeds["Center"]
tag_dict["huddle"] = smooth_boolean_array(
huddle(coords, speeds, huddle_forward, huddle_spine, huddle_speed)
)
if classifiers and "huddle" in classifiers:
mouse_X = {
_id: np.array(
coords[vid_name][
[
j
for j in coords[vid_name].keys()
if (len(j) == 2 and _id in j[0] and _id in j[1])
]
]
)
for _id in animal_ids
}
for _id in animal_ids:
tag_dict[_id + "_huddle"] = smooth_boolean_array(
classifiers["huddle"].predict(mouse_X[_id])
)
else:
try:
for _id in animal_ids:
tag_dict[_id + "_huddle"] = smooth_boolean_array(
huddle(coords, speeds, huddle_forward, huddle_spine, huddle_speed)
if any([show, save]):
cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
# Keep track of the frame number, to align with the tracking data
fnum = 0
writer = None
frame_speeds = {_id: -np.inf for _id in animal_ids} if animal_ids else -np.inf
# Loop over the frames in the video
pbar = tqdm(total=min(coords.shape[0] - recog_limit, frame_limit))
while cap.isOpened() and fnum < frame_limit:
ret, frame = cap.read()
# if frame is read correctly ret is True
if not ret:
print("Can't receive frame (stream end?). Exiting ...")
break
font = cv2.FONT_HERSHEY_COMPLEX_SMALL
# Label positions
downleft = (int(w * 0.3 / 10), int(h / 1.05))
downright = (int(w * 6.5 / 10), int(h / 1.05))
upleft = (int(w * 0.3 / 10), int(h / 20))
upright = (int(w * 6.3 / 10), int(h / 20))
# Capture speeds
try:
if list(frame_speeds.values())[0] == -np.inf or fnum % speed_pause == 0:
for _id in animal_ids:
frame_speeds[_id] = speeds[_id + "_Center"][fnum]
except AttributeError:
if frame_speeds == -np.inf or fnum % speed_pause == 0:
frame_speeds = speeds["Center"][fnum]
# Display all annotations in the output video
if animal_ids:
if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
cv2.putText(
frame,
"Nose-Nose",
(
downleft
if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
else downright
),
font,
1,
(255, 255, 255),
2,
)
if (
tag_dict[animal_ids[0] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
):
cv2.putText(
frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2
)
if (
tag_dict[animal_ids[1] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
):
cv2.putText(
frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2
)
if tag_dict["sidebyside"][fnum]:
cv2.putText(
frame,
"Side-side",
(
downleft
if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
else downright
),
font,
1,
(255, 255, 255),
2,
)
if tag_dict["sidereside"][fnum]:
cv2.putText(
frame,
"Side-Rside",
(
downleft
if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
else downright
),
font,
1,
(255, 255, 255),
2,
)
for _id, down_pos, up_pos in zip(
animal_ids, [downleft, downright], [upleft, upright]
):
if tag_dict[_id + "_climbing"][fnum]:
cv2.putText(
frame, "Climbing", down_pos, font, 1, (255, 255, 255), 2
)
if (
tag_dict[_id + "_huddle"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
):
cv2.putText(
frame, "Huddling", down_pos, font, 1, (255, 255, 255), 2
)
if (
tag_dict[_id + "_following"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
):
cv2.putText(
frame,
"*f",
(int(w * 0.3 / 10), int(h / 10)),
font,
1,
(
(150, 150, 255)
if frame_speeds[animal_ids[0]]
> frame_speeds[animal_ids[1]]
else (150, 255, 150)
),
2,
)
cv2.putText(
frame,
_id + ": " + str(np.round(frame_speeds[_id], 2)) + " mmpf",
(up_pos[0] - 20, up_pos[1]),
font,
1,
(
(150, 150, 255)
if frame_speeds[_id] == max(list(frame_speeds.values()))
else (150, 255, 150)
),
2,
)
else:
if tag_dict["climbing"][fnum]:
cv2.putText(
frame, "Climbing", downleft, font, 1, (255, 255, 255), 2
)
if tag_dict["huddle"][fnum] and not tag_dict["climbing"][fnum]:
cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2)
cv2.putText(
frame,
str(np.round(frame_speeds, 2)) + " mmpf",
upleft,
font,
1,
(
(150, 150, 255)
if huddle_speed > frame_speeds
else (150, 255, 150)
),
2,
)
except TypeError:
tag_dict["huddle"] = smooth_boolean_array(
huddle(coords, speeds, huddle_forward, huddle_spine, huddle_speed)
)
# if any([show, save]):
# cap = cv2.VideoCapture(path + videos[vid_index])
#
# # Keep track of the frame number, to align with the tracking data
# fnum = 0
# if save:
# writer = None
#
# # Loop over the frames in the video
# pbar = tqdm(total=min(coords.shape[0] - recog_limit, frame_limit))
# while cap.isOpened() and fnum < frame_limit:
#
# ret, frame = cap.read()
# # if frame is read correctly ret is True
# if not ret:
# print("Can't receive frame (stream end?). Exiting ...")
# break
#
# font = cv2.FONT_HERSHEY_COMPLEX_SMALL
#
# if like_qc_dict[vid_name][fnum]:
#
# # Extract positions
# pos_dict = {
# i: np.array([coords[i]["x"][fnum], coords[i]["y"][fnum]])
# for i in coords.columns.levels[0]
# if i != "Like_QC"
# }
#
# if h is None and w is None:
# h, w = frame.shape[0], frame.shape[1]
#
# # Label positions
# downleft = (int(w * 0.3 / 10), int(h / 1.05))
# downright = (int(w * 6.5 / 10), int(h / 1.05))
# upleft = (int(w * 0.3 / 10), int(h / 20))
# upright = (int(w * 6.3 / 10), int(h / 20))
#
# # Display all annotations in the output video
# if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
# cv2.putText(
# frame,
# "Nose-Nose",
# (downleft if bspeed > wspeed else downright),
# font,
# 1,
# (255, 255, 255),
# 2,
# )
# if tag_dict["bnose2tail"][fnum] and not tag_dict["sidereside"][fnum]:
# cv2.putText(
# frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2
# )
# if tag_dict["wnose2tail"][fnum] and not tag_dict["sidereside"][fnum]:
# cv2.putText(
# frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2
# )
# if tag_dict["sidebyside"][fnum]:
# cv2.putText(
# frame,
# "Side-side",
# (downleft if bspeed > wspeed else downright),
# font,
# 1,
# (255, 255, 255),
# 2,
# )
# if tag_dict["sidereside"][fnum]:
# cv2.putText(
# frame,
# "Side-Rside",
# (downleft if bspeed > wspeed else downright),
# font,
# 1,
# (255, 255, 255),
# 2,
# )
# if tag_dict["bclimbwall"][fnum]:
# cv2.putText(
# frame, "Climbing", downleft, font, 1, (255, 255, 255), 2
# )
# if tag_dict["wclimbwall"][fnum]:
# cv2.putText(
# frame, "Climbing", downright, font, 1, (255, 255, 255), 2
# )
# if tag_dict["bhuddle"][fnum] and not tag_dict["bclimbwall"][fnum]:
# cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2)
# if tag_dict["whuddle"][fnum] and not tag_dict["wclimbwall"][fnum]:
# cv2.putText(frame, "huddle", downright, font, 1, (255, 255, 255), 2)
# if tag_dict["bfollowing"][fnum] and not tag_dict["bclimbwall"][fnum]:
# cv2.putText(
# frame,
# "*f",
# (int(w * 0.3 / 10), int(h / 10)),
# font,
# 1,
# ((150, 150, 255) if wspeed > bspeed else (150, 255, 150)),
# 2,
# )
# if tag_dict["wfollowing"][fnum] and not tag_dict["wclimbwall"][fnum]:
# cv2.putText(
# frame,
# "*f",
# (int(w * 6.3 / 10), int(h / 10)),
# font,
# 1,
# ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)),
# 2,
# )
#
# if (bspeed == None and wspeed == None) or fnum % speed_pause == 0:
# bspeed = tag_dict["bspeed"][fnum]
# wspeed = tag_dict["wspeed"][fnum]
#
# cv2.putText(
# frame,
# "W: " + str(np.round(wspeed, 2)) + " mmpf",
# (upright[0] - 20, upright[1]),
# font,
# 1,
# ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)),
# 2,
# )
# cv2.putText(
# frame,
# "B: " + str(np.round(bspeed, 2)) + " mmpf",
# upleft,
# font,
# 1,
# ((150, 150, 255) if bspeed < wspeed else (150, 255, 150)),
# 2,
# )
#
# if show:
# cv2.imshow("frame", frame)
#
# if save:
#
# if writer is None:
# # Define the codec and create a VideoWriter object; the output is stored in a "*_tagged.avi" file
# # Pass the FPS and frame size explicitly
# writer = cv2.VideoWriter()
# writer.open(
# re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
# cv2.VideoWriter_fourcc(*"MJPG"),
# fps,
# (frame.shape[1], frame.shape[0]),
# True,
# )
# writer.write(frame)
#
# if cv2.waitKey(1) == ord("q"):
# break
#
# pbar.update(1)
# fnum += 1
#
# cap.release()
# cv2.destroyAllWindows()
if show:
cv2.imshow("frame", frame)
if cv2.waitKey(1) == ord("q"):
break
if save:
if writer is None:
# Define the codec and create a VideoWriter object; the output is stored in a "*_tagged.avi" file
# Pass the FPS and frame size explicitly
writer = cv2.VideoWriter()
writer.open(
re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
cv2.VideoWriter_fourcc(*"MJPG"),
fps,
(frame.shape[1], frame.shape[0]),
True,
)
writer.write(frame)
pbar.update(1)
fnum += 1
cap.release()
cv2.destroyAllWindows()
tag_df = pd.DataFrame(tag_dict)
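
`smooth_boolean_array` is applied to every tag above to suppress single-frame flickers. Its implementation is not part of this diff; a plausible minimal version would be a rolling majority vote (an assumption, not the deepof code):

```python
import pandas as pd

def smooth_boolean_array_sketch(a: pd.Series, window: int = 5) -> pd.Series:
    # Majority vote over a centred window: isolated one-frame tags are dropped
    # and one-frame gaps inside longer bouts are filled in
    return a.rolling(window, center=True, min_periods=1).mean() > 0.5
```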
@@ -1218,7 +1225,3 @@ def rule_based_tagging(
# - Add digging to rule_based_tagging
# - Add center to rule_based_tagging
# - Check for features requested by Joeri
# - Check speed. Avoid recomputing unnecessarily
# - Pass thresholds as parameters of the function. Provide defaults (we should tune them in the future)
# - Check if attributes I'm asking for (eg arena) are already stored in Table_dict metadata
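
For reference, the climbing criterion used throughout the tagging code reduces to a distance test against the detected arena circle. A standalone sketch, assuming arena-centred coordinates (consistent with this commit switching the `cdist` reference point from `arena[:2]` to the origin):

```python
import numpy as np
import pandas as pd
from scipy import spatial

def climbing_sketch(nose_xy: np.ndarray, radius: float, w: int, index: pd.Index) -> pd.Series:
    """Flag frames where the nose lies outside the arena circle plus a width-scaled tolerance."""
    # nose_xy: (n_frames, 2) nose positions, assumed already centred on the arena
    dist = spatial.distance.cdist(nose_xy, np.zeros([1, 2])).reshape(-1)
    # Same threshold as above: detected radius plus w / 200 pixels of slack
    return pd.Series(dist > (w / 200 + radius), index=index).astype(bool)
```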
@@ -754,9 +754,7 @@ def test_cluster_transition_matrix(sampler, autocorrelation, return_graph):
assert type(trans) == np.ndarray
-@settings(deadline=None)
-@given(sampler=st.data())
-def test_rule_based_tagging(sampler):
+def test_rule_based_tagging():
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
@@ -772,8 +770,9 @@ def test_rule_based_tagging(sampler):
["test_video_circular_arena.mp4"],
prun,
vid_index=0,
- arena_abs=380,
path=os.path.join(".", "tests", "test_examples", "Videos"),
+ save=True,
+ frame_limit=100,
)
assert type(hardcoded_tags) == pd.DataFrame
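
Two follow-up assertions one might add to tighten this test; the row cap follows from `frame_limit=100` above, and the `climbing` column name is inferred from the single-animal branch of the tagging code:

```python
# Hypothetical extra checks, not part of this commit
assert hardcoded_tags.shape[0] <= 100  # capped by frame_limit=100
assert "climbing" in hardcoded_tags.columns  # single-animal tag column
```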