From 6f74d1f734bd6e0bc47d813d3d73df7d7993d5f0 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 15 Sep 2020 00:31:56 +0200
Subject: [PATCH] Added first tests for rule_based_tagging

---
 deepof/preprocess.py |  8 ++---
 deepof/utils.py      | 76 +++++++++++++++++++++-----------------------
 deepof/visuals.py    |  2 +-
 tests/test_utils.py  | 52 +++++++++++++++++++++++-------
 4 files changed, 82 insertions(+), 56 deletions(-)

diff --git a/deepof/preprocess.py b/deepof/preprocess.py
index 5ba39252..2bd16d7d 100644
--- a/deepof/preprocess.py
+++ b/deepof/preprocess.py
@@ -163,7 +163,7 @@ class project:
                             vid_index,
                             path=self.video_path,
                             arena_type=self.arena,
-                        )
+                        )[0]
                         * 2
                     )
                     + self.arena_dims
@@ -376,7 +376,7 @@ class coordinates:
                         cols = tab.columns.levels[0]
                     except AttributeError:
                         cols = tab.columns
-                    vel = rolling_speed(tab, typ="coords", order=order + 1)
+                    vel = rolling_speed(tab, deriv=order + 1)
                     vel.columns = cols
                     tabs[key] = vel
 
@@ -422,7 +422,7 @@ class coordinates:
                             cols = tab.columns.levels[0]
                         except AttributeError:
                             cols = tab.columns
-                        vel = rolling_speed(tab, typ="dists", order=order + 1)
+                        vel = rolling_speed(tab, deriv=order + 1)
                         vel.columns = cols
                         tabs[key] = vel
 
@@ -453,7 +453,7 @@ class coordinates:
                             cols = tab.columns.levels[0]
                         except AttributeError:
                             cols = tab.columns
-                        vel = rolling_speed(tab, typ="dists", order=order + 1)
+                        vel = rolling_speed(tab, deriv=order + 1)
                         vel.columns = cols
                         tabs[key] = vel
 
diff --git a/deepof/utils.py b/deepof/utils.py
index fdfeea34..8a6ab13a 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -21,7 +21,7 @@ from typing import Tuple, Any, List, Union, Dict, NewType
 # DEFINE CUSTOM ANNOTATED TYPES #
 
 
-TableDict = NewType("TableDict", Any)
+Coordinates = NewType("Coordinates", Any)
 
 
 # QUALITY CONTROL AND PREPROCESSING #
@@ -408,7 +408,7 @@ def recognize_arena(
 
             # Detect arena and extract positions
             arena = circular_arena_recognition(frame)[0]
-            if h is not None and w is not None:
+            if h is None and w is None:
                 h, w = frame.shape[0], frame.shape[1]
 
         fnum += 1
@@ -551,7 +551,7 @@ def huddle(
         < tol_forward
     )
 
-    spine = ["Spine1", "Center", "Spine2", "Tail_base"]
+    spine = ["Spine_1", "Center", "Spine_2", "Tail_base"]
     spine_dists = []
     for comb in range(2):
         spine_dists.append(
@@ -871,7 +871,7 @@ def cluster_transition_matrix(
 def rule_based_tagging(
     tracks: List,
     videos: List,
-    table_dict: TableDict,
+    coordinates: Coordinates,
     vid_index: int,
     arena_abs: int,
     animal_ids: List = None,
@@ -884,18 +884,13 @@ def rule_based_tagging(
     path: str = os.path.join("./"),
     arena_type: str = "circular",
     classifiers: Dict = None,
-) -> Tuple[pd.DataFrame, Any]:
-    """Outputs a dataframe with the motives registered per frame. If mp4==True, outputs a video in mp4 format"""
-
-    # noinspection PyProtectedMember
-    assert table_dict._type == "merged", (
-        "Table_dict must be of merged type, "
-        "and contain at least position, speed and distance information"
-    )
+) -> pd.DataFrame:
+    """Outputs a dataframe with the motives registered per frame."""
 
     vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
 
-    dframe = table_dict[vid_name]
+    distances = coordinates.get_coords()[vid_name]
+    speeds = coordinates.get_coords(speed=1)[vid_name]
     arena, h, w = recognize_arena(videos, vid_index, path, recog_limit, arena_type)
 
     # Dictionary with motives per frame
@@ -913,15 +908,15 @@ def rule_based_tagging(
                 behavioural_tags.append(_id + behaviour)
 
     else:
-        behavioural_tags.append(["huddle", "climbing", "speed"])
+        behavioural_tags += ["huddle", "climbing", "speed"]
 
-    tag_dict = {tag: np.zeros(dframe.shape[0]) for tag in behavioural_tags}
+    tag_dict = {tag: np.zeros(distances.shape[0]) for tag in behavioural_tags}
 
     if animal_ids:
         # Define behaviours that can be computed on the fly from the distance matrix
         tag_dict["nose2nose"] = smooth_boolean_array(
             close_single_contact(
-                dframe,
+                distances,
                 animal_ids[0] + "_Nose",
                 animal_ids[1] + "_Nose",
                 15.0,
@@ -931,7 +926,7 @@ def rule_based_tagging(
         )
         tag_dict[animal_ids[0] + "_nose2tail"] = smooth_boolean_array(
             close_single_contact(
-                dframe,
+                distances,
                 animal_ids[0] + "_Nose",
                 animal_ids[1] + "_Tail_base",
                 15.0,
@@ -941,7 +936,7 @@ def rule_based_tagging(
         )
         tag_dict[animal_ids[1] + "_nose2tail"] = smooth_boolean_array(
             close_single_contact(
-                dframe,
+                distances,
                 animal_ids[1] + "_Nose",
                 animal_ids[0] + "_Tail_base",
                 15.0,
@@ -951,7 +946,7 @@ def rule_based_tagging(
         )
         tag_dict["sidebyside"] = smooth_boolean_array(
             close_double_contact(
-                dframe,
+                distances,
                 animal_ids[0] + "_Nose",
                 animal_ids[0] + "_Tail_base",
                 animal_ids[1] + "_Nose",
@@ -964,7 +959,7 @@ def rule_based_tagging(
         )
         tag_dict["sidereside"] = smooth_boolean_array(
             close_double_contact(
-                dframe,
+                distances,
                 animal_ids[0] + "_Nose",
                 animal_ids[0] + "_Tail_base",
                 animal_ids[1] + "_Nose",
@@ -978,8 +973,8 @@ def rule_based_tagging(
         for _id in animal_ids:
             tag_dict[_id + "_following"] = smooth_boolean_array(
                 following_path(
-                    dframe[vid_name],
-                    dframe,
+                    distances[vid_name],
+                    distances,
                     follower=_id,
                     followed=[i for i in animal_ids if i != _id][0],
                     frames=20,
@@ -990,38 +985,37 @@ def rule_based_tagging(
                 pd.Series(
                     (
                         spatial.distance.cdist(
-                            np.array(dframe[_id + "_Nose"]), np.array([arena[:2]])
+                            np.array(distances[_id + "_Nose"]), np.array([arena[:2]])
                         )
                         > (w / 200 + arena[2])
-                    ).reshape(dframe.shape[0]),
-                    index=dframe.index,
+                    ).reshape(distances.shape[0]),
+                    index=distances.index,
                 )
             )
-            tag_dict[_id + "speed"] = rolling_speed(
-                dframe[_id + "_Center"], window=speed_pause
-            )
+            tag_dict[_id + "_speed"] = speeds[_id + "_speed"]
 
     else:
+        print(w)
         tag_dict["climbwall"] = smooth_boolean_array(
             pd.Series(
                 (
                     spatial.distance.cdist(
-                        np.array(dframe["Nose"]), np.array([arena[:2]])
+                        np.array(distances["Nose"]), np.array([arena[:2]])
                     )
                     > (w / 200 + arena[2])
-                ).reshape(dframe.shape[0]),
-                index=dframe.index,
+                ).reshape(distances.shape[0]),
+                index=distances.index,
             )
         )
-        tag_dict["speed"] = rolling_speed(dframe["Center"], window=speed_pause)
+        tag_dict["speed"] = speeds["Center"]
 
-    if "huddle" in classifiers:
+    if classifiers and "huddle" in classifiers:
         mouse_X = {
             _id: np.array(
-                dframe[vid_name][
+                distances[vid_name][
                     [
                         j
-                        for j in dframe[vid_name].keys()
+                        for j in distances[vid_name].keys()
                         if (len(j) == 2 and _id in j[0] and _id in j[1])
                     ]
                 ]
@@ -1036,10 +1030,10 @@ def rule_based_tagging(
         try:
             for _id in animal_ids:
                 tag_dict[_id + "_huddle"] = smooth_boolean_array(
-                    huddle(dframe, 25, 25, 5)
+                    huddle(distances, 25, 25, 5)
                 )
         except TypeError:
-            tag_dict["huddle"] = smooth_boolean_array(huddle(dframe, 25, 25, 5))
+            tag_dict["huddle"] = smooth_boolean_array(huddle(distances, 25, 25, 5))
 
     # if any([show, save]):
     #     cap = cv2.VideoCapture(path + videos[vid_index])
@@ -1201,9 +1195,9 @@ def rule_based_tagging(
     # cap.release()
     # cv2.destroyAllWindows()
 
-    tagdf = pd.DataFrame(tag_dict)
+    tag_df = pd.DataFrame(tag_dict)
 
-    return tagdf, arena
+    return tag_df
 
 
 # TODO:
@@ -1211,3 +1205,7 @@ def rule_based_tagging(
 #    - Add digging to rule_based_tagging
 #    - Add center to rule_based_tagging
 #    - Check for features requested by Joeri
+
+#    - Check the speed computation. Avoid recomputing it unnecessarily
+#    - Pass thresholds as parameters of the function, with default values (we should tune them in the future)
+#    - Check whether the attributes I'm asking for (e.g. arena) are already stored in Table_dict metadata
diff --git a/deepof/visuals.py b/deepof/visuals.py
index c45aaea3..bcc6b9f2 100644
--- a/deepof/visuals.py
+++ b/deepof/visuals.py
@@ -84,7 +84,7 @@ def model_comparison_plot(
         bars.append(
             spl.bar(
                 xpos,
-                m_bic[i * len(n_components_range) : (i + 1) * len(n_components_range)],
+                m_bic[i * len(n_components_range): (i + 1) * len(n_components_range)],
                 color=color,
                 width=0.2,
             )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d7c447ed..0d27427b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -376,7 +376,7 @@ def test_close_double_contact(pos_dframe, tol, rev):
     )
     pos_dframe.columns = idx
     close_contact = close_double_contact(
-        pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, rev, 1, 1
+        pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, 1, 1, rev
     )
     assert close_contact.dtype == bool
     assert np.array(close_contact).shape[0] <= pos_dframe.shape[0]
@@ -386,14 +386,21 @@ def test_close_double_contact(pos_dframe, tol, rev):
 @given(indexes=st.data())
 def test_recognize_arena_and_subfunctions(indexes):
 
-    path = "./tests/test_examples/Videos/"
+    path = os.path.join(".", "tests", "test_examples", "Videos")
     videos = [i for i in os.listdir(path) if i.endswith("mp4")]
 
     vid_index = indexes.draw(st.integers(min_value=0, max_value=len(videos) - 1))
     recoglimit = indexes.draw(st.integers(min_value=1, max_value=10))
 
-    assert recognize_arena(videos, vid_index, path, recoglimit, "") == 0
+    assert recognize_arena(videos, vid_index, path, recoglimit, "")[0] == 0
     assert len(recognize_arena(videos, vid_index, path, recoglimit, "circular")) == 3
+    assert len(recognize_arena(videos, vid_index, path, recoglimit, "circular")[0]) == 3
+    assert (
+        type(recognize_arena(videos, vid_index, path, recoglimit, "circular")[1]) == int
+    )
+    assert (
+        type(recognize_arena(videos, vid_index, path, recoglimit, "circular")[2]) == int
+    )
 
 
 @settings(deadline=None)
@@ -410,9 +417,9 @@ def test_climb_wall(arena, tol):
 
     prun = (
         deepof.preprocess.project(
-            path="./tests/test_examples",
+            path=os.path.join(".", "tests", "test_examples"),
             arena="circular",
-            arena_dims=[arena[0]],
+            arena_dims=[arena[2]],
             angles=False,
             video_format=".mp4",
             table_format=".h5",
@@ -495,8 +502,9 @@ def test_rolling_speed(dframe, sampler):
     ),
     tol_forward=st.floats(min_value=0.01, max_value=4.98),
     tol_spine=st.floats(min_value=0.01, max_value=4.98),
+    tol_speed=st.floats(min_value=0.01, max_value=4.98),
 )
-def test_huddle(pos_dframe, tol_forward, tol_spine):
+def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed):
 
     idx = pd.MultiIndex.from_product(
         [
@@ -505,9 +513,9 @@ def test_huddle(pos_dframe, tol_forward, tol_spine):
                 "Right_ear",
                 "Left_fhip",
                 "Right_fhip",
-                "Spine1",
+                "Spine_1",
                 "Center",
-                "Spine2",
+                "Spine_2",
                 "Tail_base",
             ],
             ["X", "y"],
@@ -515,7 +523,7 @@ def test_huddle(pos_dframe, tol_forward, tol_spine):
         names=["bodyparts", "coords"],
     )
     pos_dframe.columns = idx
-    hudd = huddle(pos_dframe, tol_forward, tol_spine)
+    hudd = huddle(pos_dframe, tol_forward, tol_spine, tol_speed)
 
     assert hudd.dtype == bool
     assert np.array(hudd).shape[0] == pos_dframe.shape[0]
@@ -734,6 +742,26 @@ def test_cluster_transition_matrix(sampler, autocorrelation, return_graph):
 
 
 @settings(deadline=None)
-@given()
-def test_rule_based_tagging():
-    pass
\ No newline at end of file
+@given(sampler=st.data())
+def test_rule_based_tagging(sampler):
+    """Check that rule_based_tagging returns a four-column DataFrame for the circular-arena example."""
+    prun = deepof.preprocess.project(
+        path=os.path.join(".", "tests", "test_examples"),
+        arena="circular",
+        arena_dims=[380],
+        angles=False,
+        video_format=".mp4",
+        table_format=".h5",
+    ).run(verbose=False)
+    # Track names get a trailing "_" because rule_based_tagging extracts the video name with the regex "(.*?)_"
+    hardcoded_tags = rule_based_tagging(
+        [i + "_" for i in prun.get_coords().keys()],
+        ["test_video_circular_arena.mp4"],
+        prun,
+        vid_index=0,
+        arena_abs=380,
+        path=os.path.join(".", "tests", "test_examples", "Videos"),
+    )
+
+    assert isinstance(hardcoded_tags, pd.DataFrame)
+    assert hardcoded_tags.shape[1] == 4
-- 
GitLab