Commit 6dac3405 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for max_behaviour in utils.py

parent 4f990b3e
......@@ -965,14 +965,34 @@ def tag_video(
return tagdf, arena
def max_behaviour(array, window_size=50):
"""Returns the most frequent behaviour in a window of window_size frames"""
array = array.drop(["bspeed", "wspeed"], axis=1).astype("float")
win_array = array.rolling(window_size, center=True).sum()[::50]
def max_behaviour(
behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.array:
"""Returns the most frequent behaviour in a window of window_size frames
Parameters:
- behaviour_dframe (pd.DataFrame): boolean matrix containing occurrence
of tagged behaviours per frame in the video
- window_size (int): size of the window to use when computing
the maximum behaviour per time slot
- stepped (bool): sliding windows don't overlap if True. False by default
Returns:
- max_array (np.array): string array with the most common behaviour per instance
of the sliding window"""
speeds = [col for col in behaviour_dframe.columns if "speed" in col.lower()]
behaviour_dframe = behaviour_dframe.drop(speeds, axis=1).astype("float")
win_array = behaviour_dframe.rolling(window_size, center=True).sum()
if stepped:
win_array = win_array[::window_size]
max_array = win_array[1:].idxmax(axis=1)
return list(max_array)
##### MACHINE LEARNING FUNCTIONS #####
return np.array(max_array)
# MACHINE LEARNING FUNCTIONS #
def gmm_compute(x, n_components, cv_type):
......
......@@ -619,3 +619,31 @@ def test_single_behaviour_analysis(sampler):
assert type(out[0]) == dict
if stat_tests:
assert type(out[0]) == dict
@settings(
deadline=None, suppress_health_check=[HealthCheck.too_slow],
)
@given(
behaviour_dframe=data_frames(
index=range_indexes(min_size=100, max_size=1000),
columns=columns(
["d1", "d2", "d3", "d4", "speed1"], dtype=bool, elements=st.booleans(),
),
),
window_size=st.data(),
)
def test_max_behaviour(behaviour_dframe, window_size):
wsize1 = window_size.draw(st.integers(min_value=5, max_value=50))
wsize2 = window_size.draw(st.integers(min_value=wsize1, max_value=50))
maxbe1 = max_behaviour(behaviour_dframe, wsize1)
maxbe2 = max_behaviour(behaviour_dframe, wsize2)
assert type(maxbe1) == np.ndarray
assert type(maxbe2) == np.ndarray
assert type(maxbe1[wsize1 // 2 + 1]) == str
assert type(maxbe1[wsize2 // 2 + 1]) == str
assert maxbe1[wsize1 // 2 + 1] in behaviour_dframe.columns
assert maxbe2[wsize2 // 2 + 1] in behaviour_dframe.columns
assert len(maxbe1) >= len(maxbe2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment