Commit 47c1a86d authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored preprocess.py and pose_utils.py

parent 8e2ff965
......@@ -405,14 +405,12 @@ def rule_based_tagging(
- videos (list): list of videos to load, in the same order as tracks
- coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
- vid_index (int): index in videos of the experiment to annotate
- animal_ids (list): IDs identifying multiple animals on the arena. None if there's only one
- show (bool): if True, enables the display of the annotated video in a separate window
- save (bool): if True, saves the annotated video to an mp4 file
- mode (str): if show, enables the display of the annotated video in a separate window, saves to mp4 file
if save
- fps (float): frames per second of the analysed video. Same as input by default
- path (str): directory in which the experimental data is stored
- frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
- recog_limit (int): number of frames to use for arena recognition (1 by default)
- arena_type (str): type of the arena used in the experiments. Must be one of 'circular'"
- hparams (dict): dictionary to overwrite the default values of the hyperparameters of the functions
that the rule-based pose estimation utilizes. Values can be:
- speed_pause (int): size of the rolling window to use when computing speeds
......@@ -456,26 +454,6 @@ def rule_based_tagging(
arena[2],
)
)
tag_dict[animal_ids[0] + "_nose2tail"] = deepof.utils.smooth_boolean_array(
close_single_contact(
coords,
animal_ids[0] + "_Nose",
animal_ids[1] + "_Tail_base",
hparams["close_contact_tol"],
arena_abs,
arena[2],
)
)
tag_dict[animal_ids[1] + "_nose2tail"] = deepof.utils.smooth_boolean_array(
close_single_contact(
coords,
animal_ids[1] + "_Nose",
animal_ids[0] + "_Tail_base",
hparams["close_contact_tol"],
arena_abs,
arena[2],
)
)
tag_dict["sidebyside"] = deepof.utils.smooth_boolean_array(
close_double_contact(
coords,
......@@ -502,6 +480,17 @@ def rule_based_tagging(
arena_rel=arena[2],
)
)
for _id in animal_ids:
tag_dict[_id + "_nose2tail"] = deepof.utils.smooth_boolean_array(
close_single_contact(
coords,
_id + "_Nose",
[i for i in animal_ids if i != _id][0] + "_Tail_base",
hparams["close_contact_tol"],
arena_abs,
arena[2],
)
)
for _id in animal_ids:
tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
following_path(
......
......@@ -51,7 +51,7 @@ class project:
smooth_alpha: float = 0.1,
arena_dims: tuple = (1,),
model: str = "mouse_topview",
animal_ids: List = None,
animal_ids: List = (""),
):
self.path = path
......@@ -368,7 +368,7 @@ class coordinates:
exp_conditions: dict = None,
distances: dict = None,
angles: dict = None,
animal_ids: List = None,
animal_ids: List = (""),
):
self._tables = tables
self.distances = distances
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment