Commit 74993ad3 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added testing examples for multi animal deepof pipeline

parent 3d8d731e
......@@ -703,7 +703,9 @@ class coordinates:
tag_dict = {}
# noinspection PyTypeChecker
coords = self.get_coords(center=False)
dists = self.get_distances()
speeds = self.get_coords(speed=1)
for key in tqdm(self._tables.keys()):
video = [vid for vid in self._videos if key + "DLC" in vid][0]
......@@ -712,6 +714,7 @@ class coordinates:
self._videos,
self,
coords,
dists,
speeds,
self._videos.index(video),
arena_type=self._arena,
......
......@@ -429,6 +429,7 @@ def rule_based_tagging(
videos: List,
coordinates: Coordinates,
coords: Any,
dists: Any,
speeds: Any,
vid_index: int,
arena_type: str,
......@@ -444,6 +445,7 @@ def rule_based_tagging(
- videos (list): list of videos to load, in the same order as tracks
- coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
- coords (deepof.preprocessing.table_dict): table_dict with already processed coordinates
- dists (deepof.preprocessing.table_dict): table_dict with already processed distances
- speeds (deepof.preprocessing.table_dict): table_dict with already processed speeds
- vid_index (int): index in videos of the experiment to annotate
- path (str): directory in which the experimental data is stored
......@@ -474,6 +476,7 @@ def rule_based_tagging(
vid_name = tracks[vid_index]
coords = coords[vid_name]
dists = dists[vid_name]
speeds = speeds[vid_name]
arena_abs = coordinates.get_arenas[1][0]
arena, h, w = deepof.utils.recognize_arena(
......@@ -532,7 +535,7 @@ def rule_based_tagging(
for _id in animal_ids:
tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
following_path(
coords[vid_name],
dists,
coords,
follower=_id,
followed=[i for i in animal_ids if i != _id][0],
......@@ -553,6 +556,7 @@ def rule_based_tagging(
hparams["huddle_forward"],
hparams["huddle_spine"],
hparams["huddle_speed"],
animal_id=_id
)
)
......
......@@ -364,7 +364,7 @@ def test_frame_corners(w, h):
)
def test_rule_based_tagging(multi_animal, video_output):
multi_animal = False
multi_animal = True
if video_output:
video_output = ["test"]
......@@ -388,4 +388,4 @@ def test_rule_based_tagging(multi_animal, video_output):
hardcoded_tags = prun.rule_based_annotation(video_output=video_output, frame_limit=50)
assert type(hardcoded_tags) == deepof.data.table_dict
assert list(hardcoded_tags.values())[0].shape[1] == 3
assert list(hardcoded_tags.values())[0].shape[1] == (13 if multi_animal else 3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment