Commit 8506a7fa authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated close_double_contact in deepof.utils and its tests

parent 51c3d458
......@@ -252,8 +252,8 @@ def smooth_mult_trajectory(series: np.array, alpha: float = 0.15) -> np.array:
Parameters:
- series (numpy.array): 1D trajectory array with N (instances) - alpha (float): 0 <= alpha <= 1;
indicates the inverse weight assigned to previous observations. Higher (alpha~1) indicates less smoothing; lower
indicates more (alpha~0)
indicates the inverse weight assigned to previous observations. Higher (alpha~1) indicates less smoothing;
lower indicates more (alpha~0)
Returns:
- smoothed_series (np.array): smoothed version of the input, with equal shape"""
......@@ -291,25 +291,41 @@ def close_single_contact(
return close_contact
# Side by side (noses and tails close)
# def close_double_contact(pos_dict, left1, right2, left2, right2, tol, rev=False):
# """Takes DLC dataframe as input. Returns True when mice are side by side"""
# w_nose = pos_dict["W_Nose"]
# b_nose = pos_dict["B_Nose"]
# w_tail = pos_dict["W_Tail_base"]
# b_tail = pos_dict["B_Tail_base"]
#
# if rev:
# return (
# np.linalg.norm(w_nose - b_tail) < tol
# and np.linalg.norm(w_tail - b_nose) < tol
# )
#
# else:
# return (
# np.linalg.norm(w_nose - b_nose) < tol
# and np.linalg.norm(w_tail - b_tail) < tol
# )
def close_double_contact(
pos_dframe: pd.DataFrame,
left1: str,
left2: str,
right1: str,
right2: str,
tol: float,
rev: bool = False,
) -> np.array:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
Parameters:
- pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
to two-animal experiments.
- left1 (string): First contact point of animal 1
- left2 (string): Second contact point of animal 1
- right1 (string): First contact point of animal 2
- right2 (string): Second contact point of animal 2
- tol (float)
Returns:
- double_contact (np.array): True if the distance between the two specified points
is less than tol, False otherwise"""
if rev:
double_contact = (
np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) < tol
) & (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) < tol)
else:
double_contact = (
np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) < tol
) & (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) < tol)
return double_contact
def recognize_arena(
......
......@@ -325,7 +325,7 @@ def test_smooth_mult_trajectory(alpha, series):
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
),
),
tol=st.floats(min_value=0.01, max_value=4.98, allow_infinity=False),
tol=st.floats(min_value=0.01, max_value=4.98),
)
def test_close_single_contact(pos_dframe, tol):
......@@ -336,3 +336,36 @@ def test_close_single_contact(pos_dframe, tol):
close_contact = close_single_contact(pos_dframe, "bpart1", "bpart2", tol)
assert close_contact.dtype == bool
assert np.array(close_contact).shape[0] <= pos_dframe.shape[0]
@settings(deadline=None)
@given(
pos_dframe=data_frames(
index=range_indexes(min_size=5),
columns=columns(["X1", "y1", "X2", "y2", "X3", "y3", "X4", "y4"], dtype=float),
rows=st.tuples(
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
st.floats(min_value=1, max_value=10, allow_nan=False, allow_infinity=False),
),
),
tol=st.floats(min_value=0.01, max_value=4.98),
rev=st.booleans(),
)
def test_close_double_contact(pos_dframe, tol, rev):
idx = pd.MultiIndex.from_product(
[["bpart1", "bpart2", "bpart3", "bpart4"], ["X", "y"]],
names=["bodyparts", "coords"],
)
pos_dframe.columns = idx
close_contact = close_double_contact(
pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, rev
)
assert close_contact.dtype == bool
assert np.array(close_contact).shape[0] <= pos_dframe.shape[0]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment