Commit b657d2ba authored by lucas_miranda's avatar lucas_miranda
Browse files

refactored data.py and pose_utils.py


Former-commit-id: 678d6639
parent 65e7b0dc
......@@ -125,7 +125,7 @@ class project:
][0]
)
self.scales = self.get_scale
self.scales, self.arena_params = self.get_scale
# Set the rest of the init parameters
self.angles = True
......@@ -190,9 +190,11 @@ class project:
def get_scale(self) -> np.array:
"""Returns the arena as recognised from the videos"""
scales = []
arena_params = []
if self.arena in ["circular"]:
scales = []
for vid_index, _ in enumerate(self.videos):
ellipse = deepof.utils.recognize_arena(
self.videos,
......@@ -204,14 +206,15 @@ class project:
)[0]
scales.append(
list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1]]) * 2)
list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1] * 2]))
+ list(self.arena_dims)
)
arena_params.append(ellipse)
else:
raise NotImplementedError("arenas must be set to one of: 'circular'")
return np.array(scales)
return np.array(scales), arena_params
def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
"""Loads videos and tables into dictionaries"""
......@@ -441,9 +444,9 @@ class project:
path=self.path,
quality=quality,
scales=self.scales,
arena_params=self.arena_params,
tables=tables,
videos=self.videos,
ellipse_detection=self.ellipse_detection,
)
@subset_condition.setter
......@@ -479,18 +482,18 @@ class coordinates:
path: str,
quality: dict,
scales: np.array,
arena_params: List,
tables: dict,
videos: list,
angles: dict = None,
animal_ids: List = tuple([""]),
distances: dict = None,
exp_conditions: dict = None,
ellipse_detection: tf.keras.models.Model = None,
):
self._animal_ids = animal_ids
self._arena = arena
self._arena_detection = arena_detection
self._ellipse_detection_model = ellipse_detection
self._arena_params = arena_params
self._arena_dims = arena_dims
self._exp_conditions = exp_conditions
self._path = path
......@@ -820,26 +823,18 @@ class coordinates:
"""Annotates coordinates using a simple rule-based pipeline"""
tag_dict = {}
# noinspection PyTypeChecker
coords = self.get_coords(center=False)
dists = self.get_distances()
speeds = self.get_coords(speed=1)
# noinspection PyTypeChecker
for key in tqdm(self._tables.keys()):
video = [vid for vid in self._videos if key + "DLC" in vid][0]
tag_dict[key] = deepof.pose_utils.rule_based_tagging(
list(self._tables.keys()),
self._videos,
self,
coords,
dists,
speeds,
self._videos.index(video),
arena_type=self._arena,
arena_detection_mode=self._arena_detection,
ellipse_detection_model=self._ellipse_detection_model,
recog_limit=1,
path=os.path.join(self._path, "Videos"),
coords=coords,
dists=dists,
speeds=speeds,
video=[vid for vid in self._videos if key + "DLC" in vid][0],
params=params,
)
......
......@@ -598,18 +598,11 @@ def frame_corners(w, h, corners: dict = {}):
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
tracks: List,
videos: List,
coordinates: Coordinates,
coords: Any,
dists: Any,
speeds: Any,
vid_index: int,
arena_type: str,
arena_detection_mode: str,
ellipse_detection_model: tf.keras.models.Model = None,
recog_limit: int = 100,
path: str = os.path.join("."),
video: str,
params: dict = {},
) -> pd.DataFrame:
"""Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
......@@ -618,10 +611,10 @@ def rule_based_tagging(
Parameters:
- tracks (list): list containing experiment IDs as strings
- videos (list): list of videos to load, in the same order as tracks
- coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
- coords (deepof.preprocessing.table_dict): table_dict with already processed coordinates
- dists (deepof.preprocessing.table_dict): table_dict with already processed distances
- speeds (deepof.preprocessing.table_dict): table_dict with already processed speeds
- coordinates (deepof.data.coordinates): coordinates object containing the project information
- coords (deepof.data.table_dict): table_dict with already processed coordinates
- dists (deepof.data.table_dict): table_dict with already processed distances
- speeds (deepof.data.table_dict): table_dict with already processed speeds
- vid_index (int): index in videos of the experiment to annotate
- path (str): directory in which the experimental data is stored
- recog_limit (int): number of frames to use for arena recognition (100 by default)
......@@ -632,6 +625,13 @@ def rule_based_tagging(
- tag_df (pandas.DataFrame): table with traits as columns and frames as rows. Each
value is a boolean indicating trait detection at a given time"""
# Extract useful information from coordinates object
tracks = list(coordinates._tables.keys())
vid_index = coordinates._videos.index(video)
arena_params = coordinates._arena_params[vid_index]
arena_type = coordinates._arena
params = get_hparameters(params)
animal_ids = coordinates._animal_ids
undercond = "_" if len(animal_ids) > 1 else ""
......@@ -646,15 +646,6 @@ def rule_based_tagging(
speeds = speeds[vid_name]
likelihoods = coordinates.get_quality()[vid_name]
arena_abs = coordinates.get_arenas[1][0]
arena, h, w = deepof.utils.recognize_arena(
videos,
vid_index,
path,
recog_limit,
coordinates._arena,
arena_detection_mode,
ellipse_detection_model,
)
# Dictionary with motives per frame
tag_dict = {}
......@@ -674,7 +665,7 @@ def rule_based_tagging(
def onebyone_contact(bparts: List):
"""Returns a smooth boolean array with 1to1 contacts between two mice"""
nonlocal coords, animal_ids, params, arena_abs, arena
nonlocal coords, animal_ids, params, arena_abs, arena_params
try:
left = animal_ids[0] + bparts[0]
......@@ -693,14 +684,14 @@ def rule_based_tagging(
(right if not isinstance(left, list) else left),
params["close_contact_tol"],
arena_abs,
arena[1][1],
arena_params[1][1],
)
)
def twobytwo_contact(rev):
"""Returns a smooth boolean array with side by side contacts between two mice"""
nonlocal coords, animal_ids, params, arena_abs, arena
nonlocal coords, animal_ids, params, arena_abs, arena_params
return deepof.utils.smooth_boolean_array(
close_double_contact(
coords,
......@@ -711,7 +702,7 @@ def rule_based_tagging(
params["side_contact_tol"],
rev=rev,
arena_abs=arena_abs,
arena_rel=arena[1][1],
arena_rel=arena_params[1][1],
)
)
......@@ -776,7 +767,7 @@ def rule_based_tagging(
tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
climb_wall(
arena_type,
arena,
arena_params,
coords,
params["climb_tol"],
_id + undercond + "Nose",
......@@ -786,7 +777,7 @@ def rule_based_tagging(
sniff_object(
speeds,
arena_type,
arena,
arena_params,
coords,
params["climb_tol"],
params["huddle_speed"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment