From b2f4bbb64a71b2aabcc257ed16b8794e325f4ffd Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Mon, 18 Jan 2021 23:42:03 +0100 Subject: [PATCH] Added testing examples for multi animal deepof pipeline --- deepof/pose_utils.py | 12 ++++++++---- tests/test_pose_utils.py | 4 +++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py index 40b7ad59..f5016d8d 100644 --- a/deepof/pose_utils.py +++ b/deepof/pose_utils.py @@ -556,7 +556,7 @@ def rule_based_tagging( hparams["huddle_forward"], hparams["huddle_spine"], hparams["huddle_speed"], - animal_id=_id + animal_id=_id, ) ) @@ -602,13 +602,14 @@ def tag_rulebased_frames( else: return 150, 255, 150 - zipped_pos = zip( + zipped_pos = list(zip( animal_ids, [corners["downleft"], corners["downright"]], [corners["upleft"], corners["upright"]], - ) + )) if len(animal_ids) > 1: + if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]: write_on_frame("Nose-Nose", conditional_pos()) if ( @@ -659,7 +660,10 @@ def tag_rulebased_frames( colcond = hparams["huddle_speed"] > frame_speeds write_on_frame( - str(np.round(frame_speeds, 2)) + " mmpf", + str( + np.round((frame_speeds if len(animal_ids) == 1 else frame_speeds[_id]), 2) + ) + + " mmpf", up_pos, conditional_col(cond=colcond), ) diff --git a/tests/test_pose_utils.py b/tests/test_pose_utils.py index 6f948a18..574f5e66 100644 --- a/tests/test_pose_utils.py +++ b/tests/test_pose_utils.py @@ -385,7 +385,9 @@ def test_rule_based_tagging(multi_animal, video_output): animal_ids=(["B", "W"] if multi_animal else [""]), ).run(verbose=True) - hardcoded_tags = prun.rule_based_annotation(video_output=video_output, frame_limit=50) + hardcoded_tags = prun.rule_based_annotation( + video_output=video_output, frame_limit=50 + ) assert type(hardcoded_tags) == deepof.data.table_dict assert list(hardcoded_tags.values())[0].shape[1] == (13 if multi_animal else 3) -- GitLab