diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py index 167749eff165684be95f3538a60973b51e25bdb3..12940ab90ee91020cbd0b0d36da678f562ad299d 100644 --- a/deepof/pose_utils.py +++ b/deepof/pose_utils.py @@ -245,9 +245,8 @@ def dig( animal_id += "_" speed = speed_dframe[animal_id + "Center"] < tol_speed - nose_speed = speed_dframe[animal_id + "Center"] < speed_dframe[animal_id + "Nose"] - likelihood = likelihood_dframe[animal_id + "Nose"] < tol_likelihood - digging = speed & nose_speed & likelihood + nose_likelihood = likelihood_dframe[animal_id + "Nose"] < tol_likelihood + digging = speed & nose_likelihood return digging @@ -770,6 +769,7 @@ def tag_rulebased_frames( if ( tag_dict[_id + undercond + "huddle"][fnum] and not tag_dict[_id + undercond + "climbing"][fnum] + and not tag_dict[_id + undercond + "dig"][fnum] ): write_on_frame("huddle", down_pos) if ( diff --git a/tests/test_pose_utils.py b/tests/test_pose_utils.py index 30cc8f465b61de20f836c790f61d66c3095e4bc4..2ee472ab8101051f9579221ed52879a69f062b17 100644 --- a/tests/test_pose_utils.py +++ b/tests/test_pose_utils.py @@ -127,6 +127,8 @@ def test_climb_wall(center, axes, angle, tol): index=range_indexes(min_size=5), columns=columns( [ + "X0", + "y0", "X1", "y1", "X2", @@ -146,7 +148,7 @@ def test_climb_wall(center, axes, angle, tol): tol_speed=st.floats(min_value=0.01, max_value=4.98), animal_id=st.text(min_size=0, max_size=15, alphabet=string.ascii_lowercase), ) -def test_huddle(pos_dframe, tol_forward, tol_speed, animal_id): +def test_single_animal_traits(pos_dframe, tol_forward, tol_speed, animal_id): _id = animal_id if animal_id != "": @@ -155,6 +157,7 @@ def test_huddle(pos_dframe, tol_forward, tol_speed, animal_id): idx = pd.MultiIndex.from_product( [ [ + _id + "Nose", _id + "Left_bhip", _id + "Right_bhip", _id + "Left_fhip", @@ -173,10 +176,20 @@ def test_huddle(pos_dframe, tol_forward, tol_speed, animal_id): tol_speed, animal_id, ) + digging = dig( + pos_dframe.xs("X", level="coords", axis=1, drop_level=True), + pos_dframe.xs("X", level="coords", axis=1, drop_level=True), + tol_speed, + 0.85, + animal_id, + ) assert hudd.dtype == bool + assert digging.dtype == bool assert np.array(hudd).shape[0] == pos_dframe.shape[0] + assert np.array(digging).shape[0] == pos_dframe.shape[0] assert np.sum(np.array(hudd)) <= pos_dframe.shape[0] + assert np.sum(np.array(digging)) <= pos_dframe.shape[0] @settings(max_examples=10, deadline=None, suppress_health_check=[HealthCheck.too_slow])