diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py index 40b7ad5939b9cf687fe129744d3d103ee737e137..f5016d8d4cd5a36c1c698b4d1383138d8055544c 100644 --- a/deepof/pose_utils.py +++ b/deepof/pose_utils.py @@ -556,7 +556,7 @@ def rule_based_tagging( hparams["huddle_forward"], hparams["huddle_spine"], hparams["huddle_speed"], - animal_id=_id + animal_id=_id, ) ) @@ -602,13 +602,14 @@ def tag_rulebased_frames( else: return 150, 255, 150 - zipped_pos = zip( + zipped_pos = list(zip( animal_ids, [corners["downleft"], corners["downright"]], [corners["upleft"], corners["upright"]], - ) + )) if len(animal_ids) > 1: + if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]: write_on_frame("Nose-Nose", conditional_pos()) if ( @@ -659,7 +660,10 @@ def tag_rulebased_frames( colcond = hparams["huddle_speed"] > frame_speeds write_on_frame( - str(np.round(frame_speeds, 2)) + " mmpf", + str( + np.round((frame_speeds if len(animal_ids) == 1 else frame_speeds[_id]), 2) + ) + + " mmpf", up_pos, conditional_col(cond=colcond), ) diff --git a/tests/test_pose_utils.py b/tests/test_pose_utils.py index 6f948a18d1f6b1e646af03e85ac1ef6ee2a92681..574f5e665c3d99c6082c50a564cd0973d1df64ae 100644 --- a/tests/test_pose_utils.py +++ b/tests/test_pose_utils.py @@ -385,7 +385,9 @@ def test_rule_based_tagging(multi_animal, video_output): animal_ids=(["B", "W"] if multi_animal else [""]), ).run(verbose=True) - hardcoded_tags = prun.rule_based_annotation(video_output=video_output, frame_limit=50) + hardcoded_tags = prun.rule_based_annotation( + video_output=video_output, frame_limit=50 + ) assert type(hardcoded_tags) == deepof.data.table_dict assert list(hardcoded_tags.values())[0].shape[1] == (13 if multi_animal else 3)