Commit 3d8d731e authored by lucas_miranda's avatar lucas_miranda
Browse files

Added testing examples for multi animal deepof pipeline

parent 8aa291b1
......@@ -486,6 +486,7 @@ def rule_based_tagging(
def onebyone_contact(bparts: List):
"""Returns a smooth boolean array with 1to1 contacts between two mice"""
nonlocal coords, animal_ids, hparams, arena_abs, arena
return deepof.utils.smooth_boolean_array(
close_single_contact(
coords,
......@@ -524,8 +525,9 @@ def rule_based_tagging(
tag_dict["sidereside"] = twobytwo_contact(rev=True)
for i, _id in enumerate(animal_ids):
bps = [["_Nose", "_Tail_base"], ["_Tail_base", "_Nose"]]
tag_dict[_id + "_nose2tail"] = onebyone_contact(bparts=bps)
tag_dict[_id + "_nose2tail"] = onebyone_contact(
bparts=["_Nose", "_Tail_base"]
)
for _id in animal_ids:
tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
......
......@@ -357,66 +357,35 @@ def test_frame_corners(w, h):
assert frame_corners(w, h, {"downright": "test"})["downright"] == "test"
def test_rule_based_tagging():
prun = deepof.data.project(
path=os.path.join(".", "tests", "test_examples", "test_single_topview"),
arena="circular",
arena_dims=tuple([380]),
video_format=".mp4",
table_format=".h5",
animal_ids=[""],
).run(verbose=True)
hardcoded_tags = rule_based_tagging(
list([i for i in prun.get_coords().keys()]),
["testDLC_video_circular_arena.mp4"],
prun,
prun.get_coords(),
prun.get_coords(speed=1),
arena_type="circular",
vid_index=0,
path=os.path.join(
".", "tests", "test_examples", "test_single_topview", "Videos"
),
)
@settings(deadline=None)
@given(
multi_animal=st.booleans(),
video_output=st.booleans(),
)
def test_rule_based_tagging(multi_animal, video_output):
assert type(hardcoded_tags) == pd.DataFrame
assert hardcoded_tags.shape[1] == 3
multi_animal = False
if video_output:
video_output = ["test"]
def test_rule_based_video():
path = os.path.join(
".",
"tests",
"test_examples",
"test_{}_topview".format("multi" if multi_animal else "single"),
)
prun = deepof.data.project(
path=os.path.join(".", "tests", "test_examples", "test_single_topview"),
path=path,
arena="circular",
arena_dims=tuple([380]),
video_format=".mp4",
table_format=".h5",
animal_ids=[""],
animal_ids=(["B", "W"] if multi_animal else [""]),
).run(verbose=True)
hardcoded_tags = rule_based_tagging(
list([i for i in prun.get_coords().keys()]),
["testDLC_video_circular_arena.mp4"],
prun,
prun.get_coords(),
prun.get_coords(speed=1),
arena_type="circular",
vid_index=0,
path=os.path.join(
".", "tests", "test_examples", "test_single_topview", "Videos"
),
)
hardcoded_tags = prun.rule_based_annotation(video_output=video_output, frame_limit=50)
rule_based_video(
coordinates=prun,
tracks=list([i + "_" for i in prun.get_coords().keys()]),
videos=["testDLC_video_circular_arena.mp4"],
vid_index=0,
frame_limit=100,
tag_dict=hardcoded_tags,
path=os.path.join(
".", "tests", "test_examples", "test_single_topview", "Videos"
),
)
assert type(hardcoded_tags) == deepof.data.table_dict
assert list(hardcoded_tags.values())[0].shape[1] == 3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment