diff --git a/deepof/utils.py b/deepof/utils.py index aaa7083705f10e254a94de14520767a36dc1ed30..7c9becdeff741b252ba6ef0254573ada8898ad67 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -130,6 +130,7 @@ def angle(a: np.array, b: np.array, c: np.array) -> np.array: - a (2D np.array): positions over time for a bodypart - b (2D np.array): positions over time for a bodypart - c (2D np.array): positions over time for a bodypart + Returns: - ang (1D np.array): angles between the three-point-instances""" @@ -152,6 +153,7 @@ def angle_trio(bpart_array: np.array) -> np.array: Returns: - ang_trio (2D numpy.array): all-three angles between the three-point-instances""" + a, b, c = bpart_array ang_trio = np.array([angle(a, b, c), angle(a, c, b), angle(b, a, c)]) @@ -170,6 +172,7 @@ def rotate( Returns: - rotated (2D numpy.array): rotated positions over time""" + R = np.array([[np.cos(angles), -np.sin(angles)], [np.sin(angles), np.cos(angles)]]) o = np.atleast_2d(origin) @@ -353,7 +356,7 @@ def recognize_arena( Returns: - arena (np.array): 1D-array containing information about the arena. - "circular" (3-element-array) -> x-y position of the center and the radius""" + "circular" (3-element-array) -> x-y position of the center and the radius""" cap = cv2.VideoCapture(os.path.join(path, videos[vid_index])) @@ -499,6 +502,7 @@ def huddle(pos_dframe: pd.DataFrame, tol_forward: float, tol_spine: float) -> np forward limbs - tol_rear (float): Maximum tolerated average distance between spine body parts + Returns: hudd (np.array): True if the animal is huddling, False otherwise """ @@ -526,38 +530,60 @@ def huddle(pos_dframe: pd.DataFrame, tol_forward: float, tol_spine: float) -> np return hudd -def following_path(distancedf, dframe, follower="B", followed="W", frames=20, tol=0): - """Returns true if follower is closer than tol to the path that followed has walked over - the last specified number of frames""" +def following_path( + distance_dframe: pd.DataFrame, + position_dframe: pd.DataFrame, + follower: str, + followed: str, + frames: int = 20, + tol: float = 0, +) -> np.array: + """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that + followed has walked over the last specified number of frames + + Parameters: + - distance_dframe (pandas.DataFrame): distances between bodyparts; generated by the preprocess module + - position_dframe (pandas.DataFrame): position of bodyparts; generated by the preprocess module + - follower (str) identifier for the animal who's following + - followed (str) identifier for the animal who's followed + - frames (int) frames in which to track whether the process consistently occurs, + - tol (float) Maximum distance for which True is returned + + Returns: + - follow (np.array): boolean sequence, True if conditions are fulfilled, False otherwise""" # Check that follower is close enough to the path that followed has passed though in the last frames - shift_dict = {i: dframe[followed + "_Tail_base"].shift(i) for i in range(frames)} + shift_dict = { + i: position_dframe[followed + "_Tail_base"].shift(i) for i in range(frames) + } dist_df = pd.DataFrame( { - i: np.linalg.norm(dframe[follower + "_Nose"] - shift_dict[i], axis=1) + i: np.linalg.norm( + position_dframe[follower + "_Nose"] - shift_dict[i], axis=1 + ) for i in range(frames) } ) # Check that the animals are oriented follower's nose -> followed's tail right_orient1 = ( - distancedf[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))] - < distancedf[tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))] + distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))] + < distance_dframe[ + tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"])) + ] ) right_orient2 = ( - distancedf[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))] - < distancedf[tuple(sorted([follower + "_Nose", followed + "_Nose"]))] + distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))] + < distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))] ) - return pd.Series( - np.all( - np.array([(dist_df.min(axis=1) < tol), right_orient1, right_orient2]), - axis=0, - ), - index=dframe.index, + follow = np.all( + np.array([(dist_df.min(axis=1) < tol), right_orient1, right_orient2]), axis=0, ) + return follow + def single_behaviour_analysis( behaviour_name, diff --git a/tests/test_utils.py b/tests/test_utils.py index 5353bef4891e7292d76c9a441b88176f732af353..3a58ab28dee199f5daa70d1001f6e1663a0bd0eb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -517,8 +517,59 @@ def test_huddle(pos_dframe, tol_forward, tol_spine): pos_dframe.columns = idx hudd = huddle(pos_dframe, tol_forward, tol_spine) - print(hudd) - assert hudd.dtype == bool assert np.array(hudd).shape[0] == pos_dframe.shape[0] assert np.sum(np.array(hudd)) <= pos_dframe.shape[0] + + +@settings(deadline=None) +@given( + distance_dframe=data_frames( + index=range_indexes(min_size=20, max_size=20), + columns=columns( + ["d1", "d2", "d3", "d4",], + dtype=float, + elements=st.floats(min_value=-20, max_value=20), + ), + ), + position_dframe=data_frames( + index=range_indexes(min_size=20, max_size=20), + columns=columns( + ["X1", "y1", "X2", "y2", "X3", "y3", "X4", "y4",], + dtype=float, + elements=st.floats(min_value=-20, max_value=20), + ), + ), + frames=st.integers(min_value=1, max_value=20), + tol=st.floats(min_value=0.01, max_value=4.98), +) +def test_following_path(distance_dframe, position_dframe, frames, tol): + + bparts = [ + "A_Nose", + "B_Nose", + "A_Tail_base", + "B_Tail_base", + ] + + pos_idx = pd.MultiIndex.from_product( + [bparts, ["X", "y"],], names=["bodyparts", "coords"], + ) + + position_dframe.columns = pos_idx + distance_dframe.columns = [c for c in combinations(bparts, 2) if c[0][0] != c[1][0]] + + follow = following_path( + distance_dframe, + position_dframe, + follower="A", + followed="B", + frames=frames, + tol=tol, + ) + + assert follow.dtype == bool + assert len(follow) == position_dframe.shape[0] + assert len(follow) == distance_dframe.shape[0] + assert np.sum(follow) <= position_dframe.shape[0] + assert np.sum(follow) <= distance_dframe.shape[0]