Commit 88294454 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for arena-recognition functions

parent 9f0cb8e6
......@@ -10,8 +10,9 @@ joblib~=0.16.0
sklearn~=0.0
scikit-learn~=0.23.2
tqdm~=4.42.0
tensorflow~=2.2.0
tensorflow~=2.0.0
hypothesis~=5.29.0
dash~=1.11.0
plotly~=4.5.0
setuptools~=49.6.0
\ No newline at end of file
setuptools~=49.6.0
pytest~=5.3.5
\ No newline at end of file
......@@ -339,16 +339,16 @@ def recognize_arena(
"""Returns numpy.array with information about the arena recognised from the first frames
of the video. WARNING: estimates won't be reliable if the camera moves along the video.
Parameters:
- videos (list): relative paths of the videos to analise
- vid_index (int): element of videos to use
- path (string): full path of the directory where the videos are
- recoglimit (int): number of frames to use for position estimates
- arena_type (string): arena type; must be one of ['circular']
Parameters:
- videos (list): relative paths of the videos to analise
- vid_index (int): element of videos to use
- path (string): full path of the directory where the videos are
- recoglimit (int): number of frames to use for position estimates
- arena_type (string): arena type; must be one of ['circular']
Returns:
- arena (np.array): 1D-array containing information about the arena.
"circular" (3-element-array) -> x-y position of the center and the radius"""
Returns:
- arena (np.array): 1D-array containing information about the arena.
"circular" (3-element-array) -> x-y position of the center and the radius"""
cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
......@@ -377,11 +377,11 @@ def recognize_arena(
def circular_arena_recognition(frame: np.array) -> np.array:
"""Returns x,y position of the center and the radius of the recognised arena
Parameters:
- frame (np.array): numpy.array representing an individual frame of a video
- frame (np.array): numpy.array representing an individual frame of a video
Returns:
- circles (np.array): 3-element-array containing x,y positions of the center
of the arena, and a third value indicating the radius"""
Returns:
- circles (np.array): 3-element-array containing x,y positions of the center
of the arena, and a third value indicating the radius"""
# Convert image to greyscale, threshold it, blur it and detect the biggest best fitting circle
# using the Hough algorithm
......@@ -408,13 +408,32 @@ def circular_arena_recognition(frame: np.array) -> np.array:
return circles[0]
def climb_wall(arena, pos_dict, tol, mouse):
"""Returns True if the specified mouse is climbing the wall"""
def climb_wall(
arena_type: str, arena: np.array, pos_dict: pd.DataFrame, tol: float, nose: str
) -> np.array:
"""Returns True if the specified mouse is climbing the wall
Parameters:
- arena_type (str): arena type; must be one of ['circular']
- arena (np.array): contains arena location and shape details
- pos_dict (table_dict): position over time for all videos in a project
- tol (float): minimum tolerance to report a hit
- nose (str): indicates the name of the body part representing the nose of
the selected animal
Returns:
- climbing (np.array): boolean array. True if selected animal
is climbing the walls of the arena"""
nose = pos_dict[nose]
nose = pos_dict[mouse + "_Nose"]
center = np.array(arena[:2])
if arena_type == "circular":
center = np.array(arena[:2])
climbing = np.linalg.norm(nose - center, axis=1) > (arena[2] + tol)
else:
raise NotImplementedError("Supported values for arena_type are ['circular']")
return np.linalg.norm(nose - center) > arena[2] + tol
return climbing
def rolling_speed(dframe, typ, pause=10, rounds=5, order=1):
......
This diff is collapsed.
This diff is collapsed.
......@@ -7,6 +7,8 @@ from hypothesis.extra.numpy import arrays
from hypothesis.extra.pandas import range_indexes, columns, data_frames
from scipy.spatial import distance
from deepof.utils import *
import deepof.preprocess
import pytest
# QUALITY CONTROL AND PREPROCESSING #
......@@ -376,7 +378,7 @@ def test_close_double_contact(pos_dframe, tol, rev):
@given(indexes=st.data())
def test_recognize_arena_and_subfunctions(indexes):
path = "./tests/test_examples/"
path = "./tests/test_examples/Videos/"
videos = [i for i in os.listdir(path) if i.endswith("mp4")]
vid_index = indexes.draw(st.integers(min_value=0, max_value=len(videos) - 1))
......@@ -384,3 +386,39 @@ def test_recognize_arena_and_subfunctions(indexes):
assert recognize_arena(videos, vid_index, path, recoglimit, "") == 0
assert len(recognize_arena(videos, vid_index, path, recoglimit, "circular")) == 3
@settings(deadline=None)
@given(
arena=st.lists(
min_size=3, max_size=3, elements=st.integers(min_value=300, max_value=500)
),
tol=st.data(),
)
def test_climb_wall(arena, tol):
tol1 = tol.draw(st.floats(min_value=0.001, max_value=10))
tol2 = tol.draw(st.floats(min_value=tol1, max_value=10))
prun = (
deepof.preprocess.project(
path="./tests/test_examples",
arena="circular",
arena_dims=[arena[0]],
angles=False,
video_format=".mp4",
table_format=".h5",
)
.run(verbose=False)
.get_coords()
)
climb1 = climb_wall("circular", arena, prun["test"], tol1, nose="Nose")
climb2 = climb_wall("circular", arena, prun["test"], tol2, nose="Nose")
assert climb1.dtype == bool
assert climb2.dtype == bool
assert np.sum(climb1) >= np.sum(climb2)
with pytest.raises(NotImplementedError):
climb_wall("", arena, prun["test"], tol1, nose="Nose")
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment