diff --git a/deepof/preprocess.py b/deepof/preprocess.py
index 1ca4da36ffb6e3aaddfe45a91b5eebddaf04cb74..8fc9ff4b2da5723e7379dffefa60420f07be0f2a 100644
--- a/deepof/preprocess.py
+++ b/deepof/preprocess.py
@@ -708,3 +708,4 @@ def merge_tables(*args):
 # TODO:
 #   - Generate ragged training array using a metric (acceleration, maybe?)
 #   - Use something like Dynamic Time Warping to put all instances in the same length
+#   - add rule_based_annotation method to coordinates class!!
diff --git a/deepof/utils.py b/deepof/utils.py
index 9222ae786e65ecbff33257df6d948f4c739aa2ca..20350dbc2ae453f2d28a103c7c1d00bf8b924de0 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -922,7 +922,7 @@ def rule_based_tagging(
     animal_ids: List = None,
     show: bool = False,
     save: bool = False,
-    fps: float = 25.0,
+    fps: float = 0.0,
     speed_pause: int = 10,
     frame_limit: float = np.inf,
     recog_limit: int = 1,
@@ -940,26 +940,27 @@ def rule_based_tagging(
     video displaying the information in real time

     Parameters:
-        - tracks (list):
-        - videos (list):
-        - coordinates (deepof.preprocessing.coordinates):
-        - vid_index (int):
-        - animal_ids (list):
-        - show (bool):
-        - save (bool):
-        - fps (float):
-        - speed_pause (int):
-        - frame_limit (float):
-        - recog_limit (int):
-        - path (str):
-        - arena_type (str):
-        - close_contact_tol (int):
-        - side_contact_tol (int):
-        - follow_frames (int):
-        - follow_tol (int):
-        - huddle_forward (int):
-        - huddle_spine (int):
-        - huddle_speed (int):
+        - tracks (list): list containing experiment IDs as strings
+        - videos (list): list of videos to load, in the same order as tracks
+        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
+        - vid_index (int): index in videos of the experiment to annotate
+        - animal_ids (list): IDs identifying multiple animals on the arena. None if there's only one
+        - show (bool): if True, enables the display of the annotated video in a separate window
+        - save (bool): if True, saves the annotated video to an .avi file
+        - fps (float): frames per second of the analysed video. Same as input by default
+        - speed_pause (int): size of the rolling window to use when computing speeds
+        - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
+        - recog_limit (int): number of frames to use for arena recognition (1 by default)
+        - path (str): directory in which the experimental data is stored
+        - arena_type (str): type of the arena used in the experiments. Must be one of 'circular'
+        - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
+        - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
+        - follow_frames (int): number of frames during which the following trait is tracked
+        - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
+          in order to report a detection
+        - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
+        - huddle_spine (int): maximum average distance between spine body parts to report a huddle detection
+        - huddle_speed (int): maximum speed to report a huddle detection

     Returns:
         - tag_df (pandas.DataFrame): table with traits as columns and frames as rows. Each
@@ -1250,11 +1251,11 @@ def rule_based_tagging(

             writer.open(
                 re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
                 cv2.VideoWriter_fourcc(*"MJPG"),
-                fps,
+                (fps if fps != 0 else cap.get(cv2.CAP_PROP_FPS)),
                 (frame.shape[1], frame.shape[0]),
                 True,
             )

         writer.write(frame)
         pbar.update(1)
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index a092febb4d8ce1d7f8e59e9870c7ddcb13adfc3b..9f70c9fe61941a6934f4d18034ddd7c4dd56b2e6 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -219,7 +219,7 @@ def test_get_table_dicts(nodes, ego, sampler):
         align=align,
     )

-    assert (type(prep) == np.ndarray or type(prep) == tuple)
+    assert type(prep) == np.ndarray or type(prep) == tuple

     if type(prep) == tuple:
         assert type(prep[0]) == np.ndarray
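For context, a minimal sketch of the frame-rate fallback this patch introduces, assuming the video is read through a `cv2.VideoCapture` handle (called `cap` here, as in `rule_based_tagging`). Note that `cv2.CAP_PROP_FPS` by itself is only a property identifier, so the source frame rate has to be queried through `VideoCapture.get()`; the helper name below is hypothetical:

```python
import cv2


def open_writer_with_source_fps(video_path: str, out_path: str, fps: float = 0.0):
    """Open an MJPG writer; fall back to the source frame rate when fps == 0."""
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    if not ret:
        raise IOError(f"could not read a frame from {video_path}")

    # cv2.CAP_PROP_FPS is just a property index; the actual frame rate
    # comes from querying the capture object with it
    out_fps = fps if fps != 0 else cap.get(cv2.CAP_PROP_FPS)

    writer = cv2.VideoWriter()
    writer.open(
        out_path,
        cv2.VideoWriter_fourcc(*"MJPG"),
        out_fps,
        (frame.shape[1], frame.shape[0]),  # (width, height)
        True,  # colour output
    )
    return cap, writer
```

Resolving the fallback at write time rather than changing the hard-coded default keeps `fps=0.0` meaning "same as the input video", which is what the updated docstring documents.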
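And a hypothetical call to the newly documented function, reconstructed only from the parameters described in the docstring above; the experiment IDs, video list, and `coords` object are assumed to come from an already-created deepof project:

```python
from deepof.utils import rule_based_tagging

# Assumed setup (not shown): a deepof project whose coordinates object
# is loaded into `coords`, with matching track and video lists.
tracks = ["Test_1"]
videos = ["Test_1_video.mp4"]

tag_df = rule_based_tagging(
    tracks,
    videos,
    coords,        # deepof.preprocessing.coordinates object
    vid_index=0,   # annotate the first video in `videos`
    save=True,     # write the annotated video to disk
    fps=0.0,       # 0.0 -> keep the input video's frame rate
)
# tag_df: one column per tagged trait, one row per frame
```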