diff --git a/deepof/preprocess.py b/deepof/preprocess.py index c45391c4c52e26b0cc6a5deac9d76100fdd1b2ff..5ba3925250523ab9d7f687e39a69fd2703fbd472 100644 --- a/deepof/preprocess.py +++ b/deepof/preprocess.py @@ -12,6 +12,7 @@ import warnings import networkx as nx from deepof.utils import * +from deepof.visuals import * class project: diff --git a/deepof/utils.py b/deepof/utils.py index 2b8f96927e7c6d40210f36dbe64d34f3141d1a19..fdfeea3451e52a8c025572a46652b4bf0bdc4357 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -16,7 +16,12 @@ from scipy import spatial from scipy import stats from sklearn import mixture from tqdm import tqdm -from typing import Tuple, Any, List, Union +from typing import Tuple, Any, List, Union, Dict, NewType + +# DEFINE CUSTOM ANNOTATED TYPES # + + +TableDict = NewType("TableDict", Any) # QUALITY CONTROL AND PREPROCESSING # @@ -281,7 +286,12 @@ def smooth_mult_trajectory(series: np.array, alpha: float = 0.15) -> np.array: def close_single_contact( - pos_dframe: pd.DataFrame, left: str, right: str, tol: float + pos_dframe: pd.DataFrame, + left: str, + right: str, + tol: float, + arena_abs: int, + arena_rel: int, ) -> np.array: """Returns a boolean array that's True if the specified body parts are closer than tol. @@ -290,13 +300,17 @@ def close_single_contact( to two-animal experiments. - left (string): First member of the potential contact - right (string): Second member of the potential contact - - tol (float) + - tol (float): maximum distance for which a contact is reported + - arena_abs (int): length in mm of the diameter of the real arena + - arena_rel (int): length in pixels of the diameter of the arena in the video Returns: - contact_array (np.array): True if the distance between the two specified points is less than tol, False otherwise""" - close_contact = np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) < tol + close_contact = ( + np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs + ) / arena_rel < tol return close_contact @@ -308,6 +322,8 @@ def close_double_contact( right1: str, right2: str, tol: float, + arena_abs: int, + arena_rel: int, rev: bool = False, ) -> np.array: """Returns a boolean array that's True if the specified body parts are closer than tol. @@ -319,7 +335,10 @@ def close_double_contact( - left2 (string): Second contact point of animal 1 - right1 (string): First contact point of animal 2 - right2 (string): Second contact point of animal 2 - - tol (float) + - tol (float): maximum distance for which a contact is reported + - arena_abs (int): length in mm of the diameter of the real arena + - arena_rel (int): length in pixels of the diameter of the arena in the video + - rev (bool): reverses the default behaviour (nose2tail contact for both mice) Returns: - double_contact (np.array): True if the distance between the two specified points @@ -327,13 +346,25 @@ def close_double_contact( if rev: double_contact = ( - np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) < tol - ) & (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) < tol) + (np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs) + / arena_rel + < tol + ) & ( + (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs) + / arena_rel + < tol + ) else: double_contact = ( - np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) < tol - ) & (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) < tol) + (np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs) + / arena_rel + < tol + ) & ( + (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs) + / arena_rel + < tol + ) return double_contact @@ -344,7 +375,7 @@ def recognize_arena( path: str = ".", recoglimit: int = 1, arena_type: str = "circular", -) -> np.array: +) -> Tuple[np.array, int, int]: """Returns numpy.array with information about the arena recognised from the first frames of the video. WARNING: estimates won't be reliable if the camera moves along the video. @@ -357,7 +388,9 @@ def recognize_arena( Returns: - arena (np.array): 1D-array containing information about the arena. - "circular" (3-element-array) -> x-y position of the center and the radius""" + "circular" (3-element-array) -> x-y position of the center and the radius + - h (int): height of the video in pixels + - w (int): width of the video in pixels""" cap = cv2.VideoCapture(os.path.join(path, videos[vid_index])) @@ -380,7 +413,7 @@ def recognize_arena( fnum += 1 - return arena + return arena, h, w def circular_arena_recognition(frame: np.array) -> np.array: @@ -493,7 +526,9 @@ def rolling_speed( return speeds -def huddle(pos_dframe: pd.DataFrame, tol_forward: float, tol_spine: float) -> np.array: +def huddle( + pos_dframe: pd.DataFrame, tol_forward: float, tol_spine: float, tol_speed: float +) -> np.array: """Returns true when the mouse is huddling using simple rules. (!!!) Designed to work with deepof's default DLC mice models; not guaranteed to work otherwise. @@ -660,312 +695,6 @@ def single_behaviour_analysis( return return_list -# MAIN BEHAVIOUR TAGGING FUNCTION # - - -def tag_video( - Tracks, - Videos, - Track_dict, - Distance_dict, - Like_QC_dict, - vid_index, - show=False, - save=False, - fps=25.0, - speedpause=50, - framelimit=np.inf, - recoglimit=1, - path="./", - classifiers={}, -): - """Outputs a dataframe with the motives registered per frame. If mp4==True, outputs a video in mp4 format""" - - vid_name = re.findall("(.*?)_", Tracks[vid_index])[0] - - cap = cv2.VideoCapture(path + Videos[vid_index]) - dframe = Track_dict[vid_name] - h, w = None, None - bspeed, wspeed = None, None - - # Disctionary with motives per frame - tagdict = { - func: np.zeros(dframe.shape[0]) - for func in [ - "nose2nose", - "bnose2tail", - "wnose2tail", - "sidebyside", - "sidereside", - "bclimbwall", - "wclimbwall", - "bspeed", - "wspeed", - "bhuddle", - "whuddle", - "bfollowing", - "wfollowing", - ] - } - - # Keep track of the frame number, to align with the tracking data - fnum = 0 - if save: - writer = None - - # Loop over the first frames in the video to get resolution and center of the arena - while cap.isOpened() and fnum < recoglimit: - ret, frame = cap.read() - # if frame is read correctly ret is True - if not ret: - print("Can't receive frame (stream end?). Exiting ...") - break - - # Detect arena and extract positions - arena = circular_arena_recognition(frame)[0] - if h == None and w == None: - h, w = frame.shape[0], frame.shape[1] - - fnum += 1 - - # Define behaviours that can be computed on the fly from the distance matrix - tagdict["nose2nose"] = smooth_boolean_array( - Distance_dict[vid_name][("B_Nose", "W_Nose")] < 15 - ) - tagdict["bnose2tail"] = smooth_boolean_array( - Distance_dict[vid_name][("B_Nose", "W_Tail_base")] < 15 - ) - tagdict["wnose2tail"] = smooth_boolean_array( - Distance_dict[vid_name][("B_Tail_base", "W_Nose")] < 15 - ) - tagdict["sidebyside"] = smooth_boolean_array( - (Distance_dict[vid_name][("B_Nose", "W_Nose")] < 40) - & (Distance_dict[vid_name][("B_Tail_base", "W_Tail_base")] < 40) - ) - tagdict["sidereside"] = smooth_boolean_array( - (Distance_dict[vid_name][("B_Nose", "W_Tail_base")] < 40) - & (Distance_dict[vid_name][("B_Tail_base", "W_Nose")] < 40) - ) - - B_mouse_X = np.array( - Distance_dict[vid_name][ - [j for j in Distance_dict[vid_name].keys() if "B_" in j[0] and "B_" in j[1]] - ] - ) - W_mouse_X = np.array( - Distance_dict[vid_name][ - [j for j in Distance_dict[vid_name].keys() if "W_" in j[0] and "W_" in j[1]] - ] - ) - - tagdict["bhuddle"] = smooth_boolean_array(classifiers["huddle"].predict(B_mouse_X)) - tagdict["whuddle"] = smooth_boolean_array(classifiers["huddle"].predict(W_mouse_X)) - - tagdict["bclimbwall"] = smooth_boolean_array( - pd.Series( - ( - spatial.distance.cdist( - np.array(dframe["B_Nose"]), np.array([arena[:2]]) - ) - > (w / 200 + arena[2]) - ).reshape(dframe.shape[0]), - index=dframe.index, - ) - ) - tagdict["wclimbwall"] = smooth_boolean_array( - pd.Series( - ( - spatial.distance.cdist( - np.array(dframe["W_Nose"]), np.array([arena[:2]]) - ) - > (w / 200 + arena[2]) - ).reshape(dframe.shape[0]), - index=dframe.index, - ) - ) - tagdict["bfollowing"] = smooth_boolean_array( - following_path( - Distance_dict[vid_name], - dframe, - follower="B", - followed="W", - frames=20, - tol=20, - ) - ) - tagdict["wfollowing"] = smooth_boolean_array( - following_path( - Distance_dict[vid_name], - dframe, - follower="W", - followed="B", - frames=20, - tol=20, - ) - ) - - # Compute speed on a rolling window - tagdict["bspeed"] = rolling_speed(dframe["B_Center"], window=speedpause) - tagdict["wspeed"] = rolling_speed(dframe["W_Center"], window=speedpause) - - if any([show, save]): - # Loop over the frames in the video - pbar = tqdm(total=min(dframe.shape[0] - recoglimit, framelimit)) - while cap.isOpened() and fnum < framelimit: - - ret, frame = cap.read() - # if frame is read correctly ret is True - if not ret: - print("Can't receive frame (stream end?). Exiting ...") - break - - font = cv2.FONT_HERSHEY_COMPLEX_SMALL - - if Like_QC_dict[vid_name][fnum]: - - # Extract positions - pos_dict = { - i: np.array([dframe[i]["x"][fnum], dframe[i]["y"][fnum]]) - for i in dframe.columns.levels[0] - if i != "Like_QC" - } - - if h == None and w == None: - h, w = frame.shape[0], frame.shape[1] - - # Label positions - downleft = (int(w * 0.3 / 10), int(h / 1.05)) - downright = (int(w * 6.5 / 10), int(h / 1.05)) - upleft = (int(w * 0.3 / 10), int(h / 20)) - upright = (int(w * 6.3 / 10), int(h / 20)) - - # Display all annotations in the output video - if tagdict["nose2nose"][fnum] and not tagdict["sidebyside"][fnum]: - cv2.putText( - frame, - "Nose-Nose", - (downleft if bspeed > wspeed else downright), - font, - 1, - (255, 255, 255), - 2, - ) - if tagdict["bnose2tail"][fnum] and not tagdict["sidereside"][fnum]: - cv2.putText( - frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2 - ) - if tagdict["wnose2tail"][fnum] and not tagdict["sidereside"][fnum]: - cv2.putText( - frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2 - ) - if tagdict["sidebyside"][fnum]: - cv2.putText( - frame, - "Side-side", - (downleft if bspeed > wspeed else downright), - font, - 1, - (255, 255, 255), - 2, - ) - if tagdict["sidereside"][fnum]: - cv2.putText( - frame, - "Side-Rside", - (downleft if bspeed > wspeed else downright), - font, - 1, - (255, 255, 255), - 2, - ) - if tagdict["bclimbwall"][fnum]: - cv2.putText( - frame, "Climbing", downleft, font, 1, (255, 255, 255), 2 - ) - if tagdict["wclimbwall"][fnum]: - cv2.putText( - frame, "Climbing", downright, font, 1, (255, 255, 255), 2 - ) - if tagdict["bhuddle"][fnum] and not tagdict["bclimbwall"][fnum]: - cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2) - if tagdict["whuddle"][fnum] and not tagdict["wclimbwall"][fnum]: - cv2.putText(frame, "huddle", downright, font, 1, (255, 255, 255), 2) - if tagdict["bfollowing"][fnum] and not tagdict["bclimbwall"][fnum]: - cv2.putText( - frame, - "*f", - (int(w * 0.3 / 10), int(h / 10)), - font, - 1, - ((150, 150, 255) if wspeed > bspeed else (150, 255, 150)), - 2, - ) - if tagdict["wfollowing"][fnum] and not tagdict["wclimbwall"][fnum]: - cv2.putText( - frame, - "*f", - (int(w * 6.3 / 10), int(h / 10)), - font, - 1, - ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)), - 2, - ) - - if (bspeed == None and wspeed == None) or fnum % speedpause == 0: - bspeed = tagdict["bspeed"][fnum] - wspeed = tagdict["wspeed"][fnum] - - cv2.putText( - frame, - "W: " + str(np.round(wspeed, 2)) + " mmpf", - (upright[0] - 20, upright[1]), - font, - 1, - ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)), - 2, - ) - cv2.putText( - frame, - "B: " + str(np.round(bspeed, 2)) + " mmpf", - upleft, - font, - 1, - ((150, 150, 255) if bspeed < wspeed else (150, 255, 150)), - 2, - ) - - if show: - cv2.imshow("frame", frame) - - if save: - - if writer is None: - # Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file. - # Define the FPS. Also frame size is passed. - writer = cv2.VideoWriter() - writer.open( - re.findall("(.*?)_", Tracks[vid_index])[0] + "_tagged.avi", - cv2.VideoWriter_fourcc(*"MJPG"), - fps, - (frame.shape[1], frame.shape[0]), - True, - ) - writer.write(frame) - - if cv2.waitKey(1) == ord("q"): - break - - pbar.update(1) - fnum += 1 - - cap.release() - cv2.destroyAllWindows() - - tagdf = pd.DataFrame(tagdict) - - return tagdf, arena - - def max_behaviour( behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False ) -> np.array: @@ -1136,5 +865,349 @@ def cluster_transition_matrix( return trans_normed +# MAIN BEHAVIOUR TAGGING FUNCTION # + + +def rule_based_tagging( + tracks: List, + videos: List, + table_dict: TableDict, + vid_index: int, + arena_abs: int, + animal_ids: List = None, + show: bool = False, + save: bool = False, + fps: float = 25.0, + speed_pause: int = 50, + frame_limit: float = np.inf, + recog_limit: int = 1, + path: str = os.path.join("./"), + arena_type: str = "circular", + classifiers: Dict = None, +) -> Tuple[pd.DataFrame, Any]: + """Outputs a dataframe with the motives registered per frame. If mp4==True, outputs a video in mp4 format""" + + # noinspection PyProtectedMember + assert table_dict._type == "merged", ( + "Table_dict must be of merged type, " + "and contain at least position, speed and distance information" + ) + + vid_name = re.findall("(.*?)_", tracks[vid_index])[0] + + dframe = table_dict[vid_name] + arena, h, w = recognize_arena(videos, vid_index, path, recog_limit, arena_type) + + # Dictionary with motives per frame + behavioural_tags = [] + if animal_ids: + behavioural_tags.append(["nose2nose", "sidebyside", "sidereside"]) + for _id in animal_ids: + for behaviour in [ + "_nose2tail", + "_climbing", + "_huddle", + "_following", + "_speed", + ]: + behavioural_tags.append(_id + behaviour) + + else: + behavioural_tags.append(["huddle", "climbing", "speed"]) + + tag_dict = {tag: np.zeros(dframe.shape[0]) for tag in behavioural_tags} + + if animal_ids: + # Define behaviours that can be computed on the fly from the distance matrix + tag_dict["nose2nose"] = smooth_boolean_array( + close_single_contact( + dframe, + animal_ids[0] + "_Nose", + animal_ids[1] + "_Nose", + 15.0, + arena_abs, + arena[2], + ) + ) + tag_dict[animal_ids[0] + "_nose2tail"] = smooth_boolean_array( + close_single_contact( + dframe, + animal_ids[0] + "_Nose", + animal_ids[1] + "_Tail_base", + 15.0, + arena_abs, + arena[2], + ) + ) + tag_dict[animal_ids[1] + "_nose2tail"] = smooth_boolean_array( + close_single_contact( + dframe, + animal_ids[1] + "_Nose", + animal_ids[0] + "_Tail_base", + 15.0, + arena_abs, + arena[2], + ) + ) + tag_dict["sidebyside"] = smooth_boolean_array( + close_double_contact( + dframe, + animal_ids[0] + "_Nose", + animal_ids[0] + "_Tail_base", + animal_ids[1] + "_Nose", + animal_ids[1] + "_Tail_base", + 15.0, + rev=False, + arena_abs=arena_abs, + arena_rel=arena[2], + ) + ) + tag_dict["sidereside"] = smooth_boolean_array( + close_double_contact( + dframe, + animal_ids[0] + "_Nose", + animal_ids[0] + "_Tail_base", + animal_ids[1] + "_Nose", + animal_ids[1] + "_Tail_base", + 15.0, + rev=True, + arena_abs=arena_abs, + arena_rel=arena[2], + ) + ) + for _id in animal_ids: + tag_dict[_id + "_following"] = smooth_boolean_array( + following_path( + dframe[vid_name], + dframe, + follower=_id, + followed=[i for i in animal_ids if i != _id][0], + frames=20, + tol=20, + ) + ) + tag_dict[_id + "_climbwall"] = smooth_boolean_array( + pd.Series( + ( + spatial.distance.cdist( + np.array(dframe[_id + "_Nose"]), np.array([arena[:2]]) + ) + > (w / 200 + arena[2]) + ).reshape(dframe.shape[0]), + index=dframe.index, + ) + ) + tag_dict[_id + "speed"] = rolling_speed( + dframe[_id + "_Center"], window=speed_pause + ) + + else: + tag_dict["climbwall"] = smooth_boolean_array( + pd.Series( + ( + spatial.distance.cdist( + np.array(dframe["Nose"]), np.array([arena[:2]]) + ) + > (w / 200 + arena[2]) + ).reshape(dframe.shape[0]), + index=dframe.index, + ) + ) + tag_dict["speed"] = rolling_speed(dframe["Center"], window=speed_pause) + + if "huddle" in classifiers: + mouse_X = { + _id: np.array( + dframe[vid_name][ + [ + j + for j in dframe[vid_name].keys() + if (len(j) == 2 and _id in j[0] and _id in j[1]) + ] + ] + ) + for _id in animal_ids + } + for _id in animal_ids: + tag_dict[_id + "_huddle"] = smooth_boolean_array( + classifiers["huddle"].predict(mouse_X[_id]) + ) + else: + try: + for _id in animal_ids: + tag_dict[_id + "_huddle"] = smooth_boolean_array( + huddle(dframe, 25, 25, 5) + ) + except TypeError: + tag_dict["huddle"] = smooth_boolean_array(huddle(dframe, 25, 25, 5)) + + # if any([show, save]): + # cap = cv2.VideoCapture(path + videos[vid_index]) + + # # Keep track of the frame number, to align with the tracking data + # fnum = 0 + # if save: + # writer = None + + # # Loop over the frames in the video + # pbar = tqdm(total=min(dframe.shape[0] - recog_limit, frame_limit)) + # while cap.isOpened() and fnum < frame_limit: + # + # ret, frame = cap.read() + # # if frame is read correctly ret is True + # if not ret: + # print("Can't receive frame (stream end?). Exiting ...") + # break + # + # font = cv2.FONT_HERSHEY_COMPLEX_SMALL + # + # if like_qc_dict[vid_name][fnum]: + # + # # Extract positions + # pos_dict = { + # i: np.array([dframe[i]["x"][fnum], dframe[i]["y"][fnum]]) + # for i in dframe.columns.levels[0] + # if i != "Like_QC" + # } + # + # if h is None and w is None: + # h, w = frame.shape[0], frame.shape[1] + # + # # Label positions + # downleft = (int(w * 0.3 / 10), int(h / 1.05)) + # downright = (int(w * 6.5 / 10), int(h / 1.05)) + # upleft = (int(w * 0.3 / 10), int(h / 20)) + # upright = (int(w * 6.3 / 10), int(h / 20)) + # + # # Display all annotations in the output video + # if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]: + # cv2.putText( + # frame, + # "Nose-Nose", + # (downleft if bspeed > wspeed else downright), + # font, + # 1, + # (255, 255, 255), + # 2, + # ) + # if tag_dict["bnose2tail"][fnum] and not tag_dict["sidereside"][fnum]: + # cv2.putText( + # frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2 + # ) + # if tag_dict["wnose2tail"][fnum] and not tag_dict["sidereside"][fnum]: + # cv2.putText( + # frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2 + # ) + # if tag_dict["sidebyside"][fnum]: + # cv2.putText( + # frame, + # "Side-side", + # (downleft if bspeed > wspeed else downright), + # font, + # 1, + # (255, 255, 255), + # 2, + # ) + # if tag_dict["sidereside"][fnum]: + # cv2.putText( + # frame, + # "Side-Rside", + # (downleft if bspeed > wspeed else downright), + # font, + # 1, + # (255, 255, 255), + # 2, + # ) + # if tag_dict["bclimbwall"][fnum]: + # cv2.putText( + # frame, "Climbing", downleft, font, 1, (255, 255, 255), 2 + # ) + # if tag_dict["wclimbwall"][fnum]: + # cv2.putText( + # frame, "Climbing", downright, font, 1, (255, 255, 255), 2 + # ) + # if tag_dict["bhuddle"][fnum] and not tag_dict["bclimbwall"][fnum]: + # cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2) + # if tag_dict["whuddle"][fnum] and not tag_dict["wclimbwall"][fnum]: + # cv2.putText(frame, "huddle", downright, font, 1, (255, 255, 255), 2) + # if tag_dict["bfollowing"][fnum] and not tag_dict["bclimbwall"][fnum]: + # cv2.putText( + # frame, + # "*f", + # (int(w * 0.3 / 10), int(h / 10)), + # font, + # 1, + # ((150, 150, 255) if wspeed > bspeed else (150, 255, 150)), + # 2, + # ) + # if tag_dict["wfollowing"][fnum] and not tag_dict["wclimbwall"][fnum]: + # cv2.putText( + # frame, + # "*f", + # (int(w * 6.3 / 10), int(h / 10)), + # font, + # 1, + # ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)), + # 2, + # ) + # + # if (bspeed == None and wspeed == None) or fnum % speed_pause == 0: + # bspeed = tag_dict["bspeed"][fnum] + # wspeed = tag_dict["wspeed"][fnum] + # + # cv2.putText( + # frame, + # "W: " + str(np.round(wspeed, 2)) + " mmpf", + # (upright[0] - 20, upright[1]), + # font, + # 1, + # ((150, 150, 255) if wspeed < bspeed else (150, 255, 150)), + # 2, + # ) + # cv2.putText( + # frame, + # "B: " + str(np.round(bspeed, 2)) + " mmpf", + # upleft, + # font, + # 1, + # ((150, 150, 255) if bspeed < wspeed else (150, 255, 150)), + # 2, + # ) + # + # if show: + # cv2.imshow("frame", frame) + # + # if save: + # + # if writer is None: + # # Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file. + # # Define the FPS. Also frame size is passed. + # writer = cv2.VideoWriter() + # writer.open( + # re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi", + # cv2.VideoWriter_fourcc(*"MJPG"), + # fps, + # (frame.shape[1], frame.shape[0]), + # True, + # ) + # writer.write(frame) + # + # if cv2.waitKey(1) == ord("q"): + # break + # + # pbar.update(1) + # fnum += 1 + + # cap.release() + # cv2.destroyAllWindows() + + tagdf = pd.DataFrame(tag_dict) + + return tagdf, arena + + # TODO: # - Add sequence plot to single_behaviour_analysis (show how the condition varies across a specified time window) +# - Add digging to rule_based_tagging +# - Add center to rule_based_tagging +# - Check for features requested by Joeri diff --git a/deepof/visuals.py b/deepof/visuals.py index bf7dab0bfe19a67a4d0e89bc0a162dd07cc87b02..c45aaea367716060e09441a0ca52b91d621a60a8 100644 --- a/deepof/visuals.py +++ b/deepof/visuals.py @@ -11,7 +11,9 @@ from typing import List, Dict # PLOTTING FUNCTIONS # -def plot_speed(behaviour_dict: dict, treatments: Dict[List]) -> plt.figure: +def plot_speed( + behaviour_dict: Dict[str, pd.DataFrame], treatments: Dict[str, List] +) -> plt.figure: """Plots a histogram with the speed of the specified mouse. Treatments is expected to be a list of lists with mice keys per treatment""" diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdf4f9d1d69d0b3061eb77ad0c4a437745c4387 --- /dev/null +++ b/tests/test_preprocess.py @@ -0,0 +1,12 @@ +# @author lucasmiranda42 + +from hypothesis import given +from hypothesis import HealthCheck +from hypothesis import settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays +from hypothesis.extra.pandas import range_indexes, columns, data_frames +from scipy.spatial import distance +from deepof.utils import * +import deepof.preprocess +import pytest diff --git a/tests/test_utils.py b/tests/test_utils.py index 34739cf872454e9ea10f4b4b6818bd9d88c17988..d7c447ed0cfd17b48bf186d0b4f9a182241d1ab9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,6 +11,7 @@ from deepof.utils import * import deepof.preprocess import pytest + # AUXILIARY FUNCTIONS # @@ -343,7 +344,7 @@ def test_close_single_contact(pos_dframe, tol): [["bpart1", "bpart2"], ["X", "y"]], names=["bodyparts", "coords"], ) pos_dframe.columns = idx - close_contact = close_single_contact(pos_dframe, "bpart1", "bpart2", tol) + close_contact = close_single_contact(pos_dframe, "bpart1", "bpart2", tol, 1, 1) assert close_contact.dtype == bool assert np.array(close_contact).shape[0] <= pos_dframe.shape[0] @@ -375,7 +376,7 @@ def test_close_double_contact(pos_dframe, tol, rev): ) pos_dframe.columns = idx close_contact = close_double_contact( - pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, rev + pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, rev, 1, 1 ) assert close_contact.dtype == bool assert np.array(close_contact).shape[0] <= pos_dframe.shape[0] @@ -730,3 +731,9 @@ def test_cluster_transition_matrix(sampler, autocorrelation, return_graph): assert type(trans) == nx.Graph else: assert type(trans) == np.ndarray + + +@settings(deadline=None) +@given() +def test_rule_based_tagging(): + pass \ No newline at end of file