diff --git a/deepof/utils.py b/deepof/utils.py index 00eaf1b436eda3d01a78841e91ec33a0b8615ec1..c6c0131f7df7d63d4d6cd929020fc7f4db35c310 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import multiprocessing import networkx as nx import numpy as np +import os import pandas as pd import regex as re import scipy @@ -329,12 +330,30 @@ def close_double_contact( def recognize_arena( - video, vid_index, path=".", recoglimit=1, arena_type="circular", -): - cap = cv2.VideoCapture(path + video[vid_index]) + videos: list, + vid_index: int, + path: str = ".", + recoglimit: int = 1, + arena_type: str = "circular", +) -> np.array: + """Returns numpy.array with information about the arena recognised from the first frames + of the video. WARNING: estimates won't be reliable if the camera moves along the video. + + Parameters: + - videos (list): relative paths of the videos to analise + - vid_index (int): element of videos to use + - path (string): full path of the directory where the videos are + - recoglimit (int): number of frames to use for position estimates + - arena_type (string): arena type; must be one of ['circular'] + + Returns: + - arena (np.array): 3-element-array containing information about the arena. + "circular" -> it returns the radius and x-y position of the center""" + + cap = cv2.VideoCapture(os.path.join(path, videos[vid_index])) # Loop over the first frames in the video to get resolution and center of the arena - fnum, h, w = 0, None, None + arena, fnum, h, w = False, 0, None, None while cap.isOpened() and fnum < recoglimit: ret, frame = cap.read() @@ -347,7 +366,7 @@ def recognize_arena( # Detect arena and extract positions arena = circular_arena_recognition(frame)[0] - if h == None and w == None: + if h is not None and w is not None: h, w = frame.shape[0], frame.shape[1] fnum += 1 diff --git a/tests/test_examples/test_video_circular_arena.mp4 b/tests/test_examples/test_video_circular_arena.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9b27dd1de4ef8b73a56b255b7d1b961da8cd14b9 Binary files /dev/null and b/tests/test_examples/test_video_circular_arena.mp4 differ diff --git a/tests/test_utils.py b/tests/test_utils.py index 34896c0cbd9e3cc7c45989a222759d6337ff80eb..e3970c320a2c4734d63c5187e450924a6bb6bc78 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -228,7 +228,7 @@ def test_rotate(p): min_value=1, max_value=10, allow_nan=False, allow_infinity=False ), ), - mode_idx=st.integers(min_value=0, max_value=1) + mode_idx=st.integers(min_value=0, max_value=1), ) def test_align_trajectories(data, mode_idx): mode = ["center", "all"][mode_idx] @@ -370,3 +370,15 @@ def test_close_double_contact(pos_dframe, tol, rev): ) assert close_contact.dtype == bool assert np.array(close_contact).shape[0] <= pos_dframe.shape[0] + + +@given(indexes=st.data(), arena_type=st.integers(min_value=0, max_value=0)) +def test_recognize_arena(indexes, arena_type): + + path = "./tests/test_examples/" + videos = [i for i in os.listdir(path) if i.endswith("mp4")] + + vid_index = indexes.draw(st.integers(min_value=0, max_value=len(videos) - 1)) + recoglimit = indexes.draw(st.integers(min_value=1, max_value=10)) + + assert recognize_arena(videos, vid_index, path, recoglimit, arena_type) == 0