From b2f4bbb64a71b2aabcc257ed16b8794e325f4ffd Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Mon, 18 Jan 2021 23:42:03 +0100
Subject: [PATCH] Added testing examples for multi animal deepof pipeline

---
 deepof/pose_utils.py     | 12 ++++++++----
 tests/test_pose_utils.py |  4 +++-
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py
index 40b7ad59..f5016d8d 100644
--- a/deepof/pose_utils.py
+++ b/deepof/pose_utils.py
@@ -556,7 +556,7 @@ def rule_based_tagging(
                 hparams["huddle_forward"],
                 hparams["huddle_spine"],
                 hparams["huddle_speed"],
-                animal_id=_id
+                animal_id=_id,
             )
         )
 
@@ -602,13 +602,14 @@ def tag_rulebased_frames(
         else:
             return 150, 255, 150
 
-    zipped_pos = zip(
+    zipped_pos = list(zip(
         animal_ids,
         [corners["downleft"], corners["downright"]],
         [corners["upleft"], corners["upright"]],
-    )
+    ))
 
     if len(animal_ids) > 1:
+
         if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
             write_on_frame("Nose-Nose", conditional_pos())
         if (
@@ -659,7 +660,10 @@ def tag_rulebased_frames(
             colcond = hparams["huddle_speed"] > frame_speeds
 
         write_on_frame(
-            str(np.round(frame_speeds, 2)) + " mmpf",
+            str(
+                np.round((frame_speeds if len(animal_ids) == 1 else frame_speeds[_id]), 2)
+            )
+            + " mmpf",
             up_pos,
             conditional_col(cond=colcond),
         )
diff --git a/tests/test_pose_utils.py b/tests/test_pose_utils.py
index 6f948a18..574f5e66 100644
--- a/tests/test_pose_utils.py
+++ b/tests/test_pose_utils.py
@@ -385,7 +385,9 @@ def test_rule_based_tagging(multi_animal, video_output):
         animal_ids=(["B", "W"] if multi_animal else [""]),
     ).run(verbose=True)
 
-    hardcoded_tags = prun.rule_based_annotation(video_output=video_output, frame_limit=50)
+    hardcoded_tags = prun.rule_based_annotation(
+        video_output=video_output, frame_limit=50
+    )
 
     assert type(hardcoded_tags) == deepof.data.table_dict
     assert list(hardcoded_tags.values())[0].shape[1] == (13 if multi_animal else 3)
-- 
GitLab