Commit 3e5b4070 authored by lucas_miranda's avatar lucas_miranda
Browse files

rule_based_annotation in data.py is now ~16x faster!

parent b6e6cdbd
......@@ -15,8 +15,10 @@ Contains methods for generating training and test sets ready for model training.
"""
from collections import defaultdict
from joblib import delayed, Parallel, parallel_backend
from typing import Dict, List
from pandas_profiling import ProfileReport
from psutil import cpu_count
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
......@@ -87,7 +89,7 @@ class project:
self.animal_ids = animal_ids
self.subset_condition = None
self.distances = "All"
self.distances = "all"
self.ego = False
self.angles = True
......@@ -110,7 +112,7 @@ class project:
@property
def distances(self):
"""List. If not 'All', sets the body parts among which the
"""List. If not 'all', sets the body parts among which the
distances will be computed"""
return self._distances
......@@ -255,7 +257,7 @@ class project:
print("Computing distances...")
nodes = self.distances
if nodes == "All":
if nodes == "all":
nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]
assert [
......@@ -323,7 +325,7 @@ class project:
return angle_dict
def run(self, verbose: bool = False) -> Coordinates:
def run(self, verbose: bool = True) -> Coordinates:
"""Generates a dataset using all the options specified during initialization"""
tables, quality = self.load_tables(verbose)
......@@ -644,12 +646,18 @@ class coordinates:
"""Annotates coordinates using a simple rule-based pipeline"""
tag_dict = {}
for idx, key in tqdm(enumerate(self._tables.keys()), total=len(self._videos)):
coords = self.get_coords()
speeds = self.get_coords(speed=1)
for key in tqdm(self._tables.keys()):
video = [vid for vid in self._videos if key in vid][0]
tag_dict[key] = deepof.pose_utils.rule_based_tagging(
list(self._tables.keys()),
self._videos,
self,
idx,
coords,
speeds,
self._videos.index(video),
arena_type=self._arena,
recog_limit=1,
path=os.path.join(self._path, "Videos"),
......@@ -657,16 +665,10 @@ class coordinates:
)
if video_output: # pragma: no cover
if type(video_output) == list:
vid_idxs = video_output
elif video_output == "all":
vid_idxs = list(self._tables.keys())
else:
raise AttributeError(
"Video output must be either 'all' or a list with the names of the videos to render"
)
for idx in vid_idxs:
def output_video(idx):
"""Outputs a single annotated video. Enclosed in a function to enable parallelization"""
deepof.pose_utils.rule_based_video(
self,
list(self._tables.keys()),
......@@ -679,6 +681,19 @@ class coordinates:
hparams=hparams,
)
if type(video_output) == list:
vid_idxs = video_output
elif video_output == "all":
vid_idxs = list(self._tables.keys())
else:
raise AttributeError(
"Video output must be either 'all' or a list with the names of the videos to render"
)
njobs = cpu_count(logical=True)
with parallel_backend("threading", n_jobs=njobs):
Parallel()(delayed(output_video)(key) for key in vid_idxs)
return table_dict(
tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
)
......
......@@ -423,6 +423,8 @@ def rule_based_tagging(
tracks: List,
videos: List,
coordinates: Coordinates,
coords: Any,
speeds: Any,
vid_index: int,
arena_type: str,
recog_limit: int = 1,
......@@ -436,6 +438,8 @@ def rule_based_tagging(
- tracks (list): list containing experiment IDs as strings
- videos (list): list of videos to load, in the same order as tracks
- coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
- coords (deepof.preprocessing.table_dict): table_dict with already processed coordinates
- speeds (deepof.preprocessing.table_dict): table_dict with already processed speeds
- vid_index (int): index in videos of the experiment to annotate
- path (str): directory in which the experimental data is stored
- recog_limit (int): number of frames to use for arena recognition (1 by default)
......@@ -464,8 +468,8 @@ def rule_based_tagging(
except IndexError:
vid_name = tracks[vid_index]
coords = coordinates.get_coords()[vid_name]
speeds = coordinates.get_coords(speed=1)[vid_name]
coords = coords[vid_name]
speeds = speeds[vid_name]
arena_abs = coordinates.get_arenas[1][0]
arena, h, w = deepof.utils.recognize_arena(
videos, vid_index, path, recog_limit, coordinates._arena
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment