Commit d8470aa4 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added testing examples for multi animal deepof pipeline

parent 2abbb664
Pipeline #91825 failed with stage
in 23 minutes and 7 seconds
......@@ -381,10 +381,10 @@ def get_hparameters(hparams: dict = {}) -> dict:
defaults = {
"speed_pause": 3,
"close_contact_tol": 15,
"side_contact_tol": 15,
"follow_frames": 20,
"follow_tol": 20,
"close_contact_tol": 35,
"side_contact_tol": 80,
"follow_frames": 10,
"follow_tol": 5,
"huddle_forward": 15,
"huddle_spine": 10,
"huddle_speed": 0.1,
......@@ -527,10 +527,12 @@ def rule_based_tagging(
tag_dict["sidereside"] = twobytwo_contact(rev=True)
for i, _id in enumerate(animal_ids):
tag_dict[_id + "_nose2tail"] = onebyone_contact(
bparts=["_Nose", "_Tail_base"]
)
tag_dict[animal_ids[0] + "_nose2tail"] = onebyone_contact(
bparts=["_Nose", "_Tail_base"]
)
tag_dict[animal_ids[1] + "_nose2tail"] = onebyone_contact(
bparts=["_Tail_base", "_Nose"]
)
for _id in animal_ids:
tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
......@@ -577,7 +579,7 @@ def tag_rulebased_frames(
undercond,
hparams,
):
"""Helper function for rule_based_video. Annotates a fiven frame with on-screen information
"""Helper function for rule_based_video. Annotates a given frame with on-screen information
about the recognised patterns"""
w, h = dims
......@@ -612,6 +614,8 @@ def tag_rulebased_frames(
if len(animal_ids) > 1:
cv2.line(frame, (100, 100), (110, 100), (255, 0, 0))
if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
write_on_frame("Nose-Nose", conditional_pos())
if (
......@@ -634,16 +638,16 @@ def tag_rulebased_frames(
"Side-Rside",
conditional_pos(),
)
for _id, down_pos, up_pos in zipped_pos:
if (
tag_dict[_id + "_following"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
):
write_on_frame(
"*f",
(int(w * 0.3 / 10), int(h / 10)),
conditional_col(),
)
# for _id, down_pos, up_pos in zipped_pos:
# if (
# tag_dict[_id + "_following"][fnum]
# and not tag_dict[_id + "_climbing"][fnum]
# ):
# write_on_frame(
# "*f",
# (int(w * 0.3 / 10), int(h / 10)),
# conditional_col(),
# )
for _id, down_pos, up_pos in zipped_pos:
......@@ -792,3 +796,8 @@ def rule_based_video(
cv2.destroyAllWindows()
return True
# TODO:
# - Relativise default contact parameters
# (right now they only work if 'arena_dims' is correctly set in the project object
......@@ -389,3 +389,7 @@ def test_rule_based_tagging(multi_animal, video_output):
assert type(hardcoded_tags) == deepof.data.table_dict
assert list(hardcoded_tags.values())[0].shape[1] == (13 if multi_animal else 3)
# TODO:
# - Test if tagging is working properly!
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment