Commit cdd65a7b authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored preprocess.py and pose_utils.py

parent ad481f9b
......@@ -8,7 +8,7 @@ checks:
threshold: 4
file-lines:
config:
threshold: 500
threshold: 1000
method-complexity:
config:
threshold: 10
......
......@@ -356,6 +356,7 @@ def max_behaviour(
return np.array(max_array)
# noinspection PyDefaultArgument
def get_hparameters(hparams: dict = {}) -> dict:
"""Returns the most frequent behaviour in a window of window_size frames
......@@ -383,21 +384,18 @@ def get_hparameters(hparams: dict = {}) -> dict:
return defaults
# noinspection PyDefaultArgument
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
tracks: List,
videos: List,
coordinates: Coordinates,
vid_index: int,
animal_ids: List = None,
show: bool = False,
save: bool = False,
fps: float = 0.0,
path: str = os.path.join("./"),
hparams: dict = {},
arena_type: str = "circular",
frame_limit: float = np.inf,
recog_limit: int = 1,
mode: str = None,
fps: float = 0.0,
path: str = os.path.join("."),
hparams: dict = {},
) -> pd.DataFrame:
"""Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
video displaying the information in real time
......@@ -432,6 +430,7 @@ def rule_based_tagging(
value is a boolean indicating trait detection at a given time"""
hparams = get_hparameters(hparams)
animal_ids = coordinates._animal_ids
vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
......@@ -439,7 +438,7 @@ def rule_based_tagging(
speeds = coordinates.get_coords(speed=1)[vid_name]
arena_abs = coordinates.get_arenas[1][0]
arena, h, w = deepof.utils.recognize_arena(
videos, vid_index, path, recog_limit, arena_type
videos, vid_index, path, recog_limit, coordinates._arena
)
# Dictionary with motives per frame
......@@ -557,7 +556,7 @@ def rule_based_tagging(
)
)
if any([show, save]):
if mode in ["show", "save"]:
cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
# Keep track of the frame number, to align with the tracking data
......@@ -720,13 +719,13 @@ def rule_based_tagging(
2,
)
if show: # pragma: no cover
if mode == "show": # pragma: no cover
cv2.imshow("frame", frame)
if cv2.waitKey(1) == ord("q"):
break
if save:
if mode == "save":
if writer is None:
# Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file.
......
......@@ -45,7 +45,7 @@ class project:
self,
video_format: str = ".mp4",
table_format: str = ".h5",
path: str = ".",
path: str = os.path.join("."),
exp_conditions: dict = None,
subset_condition: list = None,
arena: str = "circular",
......@@ -55,6 +55,7 @@ class project:
ego: str = False,
angles: bool = True,
model: str = "mouse_topview",
animal_ids: List = None,
):
self.path = path
......@@ -77,6 +78,7 @@ class project:
self.ego = ego
self.angles = angles
self.scales = self.get_scale
self.animal_ids = animal_ids
model_dict = {"mouse_topview": connect_mouse_topview()}
self.connectivity = model_dict[model]
......@@ -306,6 +308,7 @@ class project:
exp_conditions=self.exp_conditions,
distances=distances,
angles=angles,
animal_ids=self.animal_ids,
)
......@@ -328,6 +331,7 @@ class coordinates:
exp_conditions: dict = None,
distances: dict = None,
angles: dict = None,
animal_ids: List = None,
):
self._tables = tables
self.distances = distances
......@@ -338,6 +342,7 @@ class coordinates:
self._arena_dims = arena_dims
self._scales = scales
self._quality = quality
self._animal_ids = animal_ids
def __str__(self):
if self._exp_conditions:
......
......@@ -349,6 +349,7 @@ def test_rule_based_tagging():
angles=False,
video_format=".mp4",
table_format=".h5",
animal_ids=None,
).run(verbose=True)
hardcoded_tags = rule_based_tagging(
......@@ -357,7 +358,7 @@ def test_rule_based_tagging():
prun,
vid_index=0,
path=os.path.join(".", "tests", "test_examples", "Videos"),
save=True,
mode="save",
frame_limit=100,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment