Commit 86cdcdee authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for arena-recognition functions

parent c15ff730
......@@ -347,8 +347,8 @@ def recognize_arena(
- arena_type (string): arena type; must be one of ['circular']
Returns:
- arena (np.array): 3-element-array containing information about the arena.
"circular" -> it returns the radius and x-y position of the center"""
- arena (np.array): 1D-array containing information about the arena.
"circular" (3-element-array) -> x-y position of the center and the radius"""
cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
......@@ -374,8 +374,14 @@ def recognize_arena(
return arena
def circular_arena_recognition(frame):
"""Returns x,y position of the center and the radius of the recognised arena"""
def circular_arena_recognition(frame: np.array) -> np.array:
"""Returns x,y position of the center and the radius of the recognised arena
Parameters:
- frame (np.array): numpy.array representing an individual frame of a video
Returns:
- circles (np.array): 3-element-array containing x,y positions of the center
of the arena, and a third value indicating the radius"""
# Convert image to greyscale, threshold it, blur it and detect the biggest best fitting circle
# using the Hough algorithm
......@@ -402,7 +408,7 @@ def circular_arena_recognition(frame):
return circles[0]
def climb_wall(arena, pos_dict, fnum, tol, mouse):
def climb_wall(arena, pos_dict, tol, mouse):
"""Returns True if the specified mouse is climbing the wall"""
nose = pos_dict[mouse + "_Nose"]
......
......@@ -372,8 +372,8 @@ def test_close_double_contact(pos_dframe, tol, rev):
assert np.array(close_contact).shape[0] <= pos_dframe.shape[0]
@given(indexes=st.data(), arena_type=st.integers(min_value=0, max_value=0))
def test_recognize_arena(indexes, arena_type):
@given(indexes=st.data())
def test_recognize_arena_and_subfunctions(indexes):
path = "./tests/test_examples/"
videos = [i for i in os.listdir(path) if i.endswith("mp4")]
......@@ -381,4 +381,8 @@ def test_recognize_arena(indexes, arena_type):
vid_index = indexes.draw(st.integers(min_value=0, max_value=len(videos) - 1))
recoglimit = indexes.draw(st.integers(min_value=1, max_value=10))
assert recognize_arena(videos, vid_index, path, recoglimit, arena_type) == 0
assert recognize_arena(videos, vid_index, path, recoglimit, 0) == 0
assert len(recognize_arena(videos, vid_index, path, recoglimit, "circular")) == 3
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment