pose_utils.py 27.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for rule-based pose estimation. See documentation for details

"""

import cv2
lucas_miranda's avatar
lucas_miranda committed
12
import deepof.utils
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import regex as re
import seaborn as sns
from itertools import combinations
from scipy import stats
from typing import Any, List, NewType

# Type alias used only in signatures (rule_based_tagging, rule_based_video) for the
# project's coordinates object. NOTE(review): it wraps Any, so type checkers get no
# real information from it.
Coordinates = NewType("Coordinates", Any)


def close_single_contact(
    pos_dframe: pd.DataFrame,
    left: str,
    right: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
) -> np.array:
    """Returns a boolean array that's True if the specified body parts are closer than tol.

    The per-frame pixel distance between both body parts is converted to mm using
    the ratio between the real arena diameter and its diameter in the video.

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left (string): First member of the potential contact
        - right (string): Second member of the potential contact
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video

    Returns:
        - contact_array (np.array): True if the distance between the two specified points
        is less than tol, False otherwise"""

    pixel_gap = np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1)
    # Rescale pixels to mm before comparing against the tolerance
    real_gap = pixel_gap * arena_abs / arena_rel
    return real_gap < tol


def close_double_contact(
    pos_dframe: pd.DataFrame,
    left1: str,
    left2: str,
    right1: str,
    right2: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
    rev: bool = False,
) -> np.array:
    """Returns a boolean array that's True if both pairs of specified body parts are closer than tol.

    By default right1 is paired with left1 and right2 with left2; with rev=True the
    pairing is crossed (right1 with left2, right2 with left1).

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left1 (string): First contact point of animal 1
        - left2 (string): Second contact point of animal 1
        - right1 (string): First contact point of animal 2
        - right2 (string): Second contact point of animal 2
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video
        - rev (bool): reverses the default behaviour (nose2tail contact for both mice)

    Returns:
        - double_contact (np.array): True if the distance between the two specified points
        is less than tol, False otherwise"""

    def gap_in_mm(part_a: str, part_b: str) -> np.ndarray:
        """Per-frame distance between two body parts, rescaled from pixels to mm"""
        pixels = np.linalg.norm(pos_dframe[part_a] - pos_dframe[part_b], axis=1)
        return pixels * arena_abs / arena_rel

    # Choose which point of animal 1 each point of animal 2 is matched with
    partner_of_right1, partner_of_right2 = (left2, left1) if rev else (left1, left2)

    return (gap_in_mm(right1, partner_of_right1) < tol) & (
        gap_in_mm(right2, partner_of_right2) < tol
    )


def climb_wall(
    arena_type: str,
    arena: np.array,
    pos_dict: pd.DataFrame,
    tol: float,
    nose: str,
    centered_data: bool = False,
) -> np.array:
    """Returns True if the specified mouse is climbing the wall

    Parameters:
        - arena_type (str): arena type; must be one of ['circular']
        - arena (np.array): arena location and shape details. For 'circular' arenas,
        arena[:2] is the center position and arena[2] the radius
        - pos_dict (pandas.DataFrame): position of body parts over time
        - tol (float): extra distance beyond the arena radius that the nose must
        exceed for a climb to be reported
        - nose (str): indicates the name of the body part representing the nose of
        the selected animal
        - centered_data (bool): indicates whether the input data is centered; if True,
        the origin is used as arena center instead of arena[:2]

    Returns:
        - climbing (np.array): boolean array. True if the nose of the selected animal
        is farther than radius + tol from the arena center

    Raises:
        - NotImplementedError: if arena_type is not 'circular'"""

    nose_position = pos_dict[nose]

    if arena_type == "circular":
        center = np.zeros(2) if centered_data else np.array(arena[:2])
        radius = arena[2]
        # Nose beyond the arena border (plus tolerance) => climbing
        climbing = np.linalg.norm(nose_position - center, axis=1) > (radius + tol)

    else:
        raise NotImplementedError("Supported values for arena_type are ['circular']")

    return climbing


def huddle(
    pos_dframe: pd.DataFrame,
    speed_dframe: pd.DataFrame,
    tol_forward: float,
    tol_spine: float,
    tol_speed: float,
    animal_id: str = "",
) -> np.array:
    """Returns true when the mouse is huddling using simple rules. (!!!) Designed to
    work with deepof's default DLC mice models; not guaranteed to work otherwise.

    A frame is tagged as huddling when all of the following hold in that frame:
    both ears are close to the corresponding forward hips, the average distance
    between consecutive spine body parts is small, and the center moves slowly.

    Parameters:
        - pos_dframe (pandas.DataFrame): position of body parts over time
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - tol_forward (float): Maximum tolerated distance between ears and
        forward limbs
        - tol_spine (float): Maximum tolerated average distance between consecutive
        spine body parts (Spine_1, Center, Spine_2, Tail_base), evaluated per frame
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - animal_id (str): identifier of the animal in multi-animal projects; body
        parts are looked up as "<animal_id>_<part>". Empty by default

    Returns:
        hudd (np.array): True if the animal is huddling, False otherwise
    """

    if animal_id != "":
        animal_id += "_"

    def part_distance(part_a: str, part_b: str) -> np.ndarray:
        """Per-frame euclidean distance between two body parts of the selected animal"""
        return np.linalg.norm(
            pos_dframe[animal_id + part_a] - pos_dframe[animal_id + part_b], axis=1
        )

    forward = (part_distance("Left_ear", "Left_fhip") < tol_forward) & (
        part_distance("Right_ear", "Right_fhip") < tol_forward
    )

    spine = ["Spine_1", "Center", "Spine_2", "Tail_base"]
    # One distance array per pair of consecutive spine parts, all the way to Tail_base
    spine_dists = [part_distance(a, b) for a, b in zip(spine[:-1], spine[1:])]
    # Average across spine pairs for each frame (axis=0); averaging over every
    # element would collapse the condition into a single value for the whole video
    compact_spine = np.mean(spine_dists, axis=0) < tol_spine
    slow = speed_dframe[animal_id + "Center"] < tol_speed

    hudd = forward & compact_spine & slow

    return hudd


def following_path(
    distance_dframe: pd.DataFrame,
    position_dframe: pd.DataFrame,
    follower: str,
    followed: str,
    frames: int = 20,
    tol: float = 0,
) -> np.array:
    """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
    'followed' has walked over the last specified number of frames, while facing it.

    Parameters:
        - distance_dframe (pandas.DataFrame): distances between bodyparts; generated by the preprocess module.
        Columns are indexed by sorted tuples of body part names
        - position_dframe (pandas.DataFrame): position of bodyparts; generated by the preprocess module
        - follower (str): identifier for the animal who's following
        - followed (str): identifier for the animal who's followed
        - frames (int): frames in which to track whether the process consistently occurs
        - tol (float): Maximum distance for which True is returned

    Returns:
        - follow (np.array): boolean sequence, True if conditions are fulfilled, False otherwise"""

    followed_tail = position_dframe[followed + "_Tail_base"]
    follower_nose = position_dframe[follower + "_Nose"]

    # Distance from the follower's nose to where the followed tail base was `lag` frames ago
    path_distances = pd.DataFrame(
        {
            lag: np.linalg.norm(follower_nose - followed_tail.shift(lag), axis=1)
            for lag in range(frames)
        }
    )

    def pair_distance(part_a: str, part_b: str) -> pd.Series:
        """Looks up a precomputed distance; keys are stored as sorted tuples"""
        return distance_dframe[tuple(sorted([part_a, part_b]))]

    nose_to_tail = pair_distance(follower + "_Nose", followed + "_Tail_base")
    tail_to_tail = pair_distance(follower + "_Tail_base", followed + "_Tail_base")
    nose_to_nose = pair_distance(follower + "_Nose", followed + "_Nose")

    # Close to the recent path AND oriented follower's nose -> followed's tail
    conditions = np.array(
        [
            path_distances.min(axis=1) < tol,
            nose_to_tail < tail_to_tail,
            nose_to_tail < nose_to_nose,
        ]
    )
    return np.all(conditions, axis=0)


def single_behaviour_analysis(
    behaviour_name: str,
    treatment_dict: dict,
    behavioural_dict: dict,
    plot: int = 0,
    stat_tests: bool = True,
    save: str = None,
    ylim: float = None,
) -> list:
    """Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
    with the actual tags, outputs a box plot and a series of significance tests amongst the groups

    Parameters:
        - behaviour_name (str): name of the behavioural trait to analyze
        - treatment_dict (dict): dictionary containing experimental conditions as keys and
        iterables of video names as values
        - behavioural_dict (dict): tagged dictionary containing video names as keys and annotations as values;
        behavioural_dict[video][behaviour_name] must support np.sum and len
        - plot (int): Silent if 0; otherwise, indicates the dpi of the figure to plot
        - stat_tests (bool): if True, runs two-sided Mann-Whitney U tests between every pair of groups
        (p-values are NOT corrected for multiple testing)
        - save (str): Saves the produced figure to the specified file. Ignored if plot == 0
        - ylim (float): passed to ax.set_ylim. Ignored if plot == 0

    Returns:
        - return_list (list): [beh_dict], followed by the figure if plot > 0, followed by stat_dict
        if stat_tests, where:
            - beh_dict (dict): experimental conditions as keys and lists with the proportion of
            frames in which the behaviour was detected (one entry per video) as values
            - stat_dict (dict): condition pairs as keys and stat results as values"""

    beh_dict = {condition: [] for condition in treatment_dict}

    for condition, videos in treatment_dict.items():
        for ind in videos:
            tags = behavioural_dict[ind][behaviour_name]
            # Proportion of frames tagged with the behaviour in this video.
            # Must be appended: extending a list with a scalar raises TypeError
            beh_dict[condition].append(np.sum(tags) / len(tags))

    return_list = [beh_dict]

    if plot > 0:

        fig, ax = plt.subplots(dpi=plot)

        # seaborn expects long-form data: one (condition, proportion) pair per video
        sns.boxplot(
            x=[condition for condition, values in beh_dict.items() for _ in values],
            y=[value for values in beh_dict.values() for value in values],
            orient="v",
            ax=ax,
        )

        ax.set_title("{} across groups".format(behaviour_name))
        ax.set_ylabel("Proportion of frames")

        if ylim is not None:
            ax.set_ylim(ylim)

        if save is not None:  # pragma: no cover
            fig.savefig(save)

        return_list.append(fig)

    if stat_tests:
        stat_dict = {}
        for pair in combinations(treatment_dict.keys(), 2):
            first, second = beh_dict[pair[0]], beh_dict[pair[1]]
            # Solves issue with automatically generated examples: the test cannot be
            # run on identical or constant samples
            if first == second or np.var(first) == 0 or np.var(second) == 0:
                stat_dict[pair] = "Identical sources. Couldn't run"
            else:
                stat_dict[pair] = stats.mannwhitneyu(
                    first, second, alternative="two-sided"
                )
        return_list.append(stat_dict)

    return return_list


def max_behaviour(
    behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.array:
    """Returns the most frequent behaviour in a window of window_size frames

    Columns whose name contains "speed" (case-insensitive) are ignored; the
    remaining tag columns are cast to float and summed over a centred rolling
    window, and the column with the highest sum is reported.

    Parameters:
        - behaviour_dframe (pd.DataFrame): boolean matrix containing occurrence
        of tagged behaviours per frame in the video
        - window_size (int): size of the window to use when computing
        the maximum behaviour per time slot
        - stepped (bool): sliding windows don't overlap if True. False by default

    Returns:
        - max_array (np.array): array with the name of the most common behaviour per
        instance of the sliding window (the first row of the window table is skipped)"""

    speed_columns = [
        name for name in behaviour_dframe.columns if "speed" in name.lower()
    ]
    tags = behaviour_dframe.drop(columns=speed_columns).astype("float")

    windowed = tags.rolling(window_size, center=True).sum()
    if stepped:
        # Keep one row per window so that windows don't overlap
        windowed = windowed[::window_size]

    return np.array(windowed[1:].idxmax(axis=1))


371
def get_hparameters(hparams: dict = None) -> dict:
    """Returns the hyperparameters used by the rule-based tagging functions, with the
    defaults overwritten by any value provided in hparams

    Parameters:
        - hparams (dict): dictionary containing hyperparameters to overwrite. None
        (the default) or an empty dict keeps every default value. Not modified

    Returns:
        - defaults (dict): dictionary with overwritten parameters. Those not
        specified in the input retain their default values"""

    defaults = {
        "speed_pause": 3,
        "close_contact_tol": 15,
        "side_contact_tol": 15,
        "follow_frames": 20,
        "follow_tol": 20,
        "huddle_forward": 15,
        "huddle_spine": 10,
        "huddle_speed": 0.1,
        "fps": 24,
    }

    # None replaces the former shared mutable {} default; both mean "no overrides"
    if hparams is not None:
        defaults.update(hparams)

    return defaults


400
401
402
403
def frame_corners(w, h, corners: dict = None):
    """Returns a dictionary with the positions, in pixels, at which on-screen labels are
    drawn near each corner of the video frame

    Parameters:
        - w (int): width of the frame in pixels
        - h (int): height of the frame in pixels
        - corners (dict): dictionary containing corners to overwrite; valid keys are
        "downleft", "downright", "upleft" and "upright". None by default (no overrides)

    Returns:
        - defaults (dict): dictionary with overwritten positions. Those not
        specified in the input retain their default values"""

    defaults = {
        "downleft": (int(w * 0.3 / 10), int(h / 1.05)),
        "downright": (int(w * 6.5 / 10), int(h / 1.05)),
        "upleft": (int(w * 0.3 / 10), int(h / 20)),
        "upright": (int(w * 6.3 / 10), int(h / 20)),
    }

    # None replaces the former shared mutable {} default; both mean "no overrides"
    if corners is not None:
        defaults.update(corners)

    return defaults


426
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
    tracks: List,
    videos: List,
    coordinates: Coordinates,
    coords: Any,
    dists: Any,
    speeds: Any,
    vid_index: int,
    arena_type: str,
    recog_limit: int = 1,
    path: str = os.path.join("."),
    hparams: dict = {},
) -> pd.DataFrame:
    """Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
    video displaying the information in real time

    Parameters:
        - tracks (list): list containing experiment IDs as strings
        - videos (list): list of videos to load, in the same order as tracks
        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
        - coords (deepof.preprocessing.table_dict): table_dict with already processed coordinates
        - dists (deepof.preprocessing.table_dict): table_dict with already processed distances
        - speeds (deepof.preprocessing.table_dict): table_dict with already processed speeds
        - vid_index (int): index in videos of the experiment to annotate
        - arena_type (str): arena type passed to climb_wall; must be one of ['circular']
        - recog_limit (int): number of frames to use for arena recognition (1 by default)
        - path (str): directory in which the experimental data is stored
        - hparams (dict): dictionary to overwrite the default values of the hyperparameters of the functions
        that the rule-based pose estimation utilizes. Values can be:
            - speed_pause (int): size of the rolling window to use when computing speeds
            - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - follow_frames (int): number of frames during which the following trait is tracked
            - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
            in order to report a detection
            - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
            - huddle_spine (int): maximum average distance between spine body parts to report a huddle detection
            - huddle_speed (int): maximum speed to report a huddle detection

    Returns:
        - tag_df (pandas.DataFrame): table with traits as columns and frames as rows. Each
        value is a boolean indicating trait detection at a given time (speed columns hold
        the speed of each animal's center instead)"""

    hparams = get_hparameters(hparams)
    animal_ids = coordinates._animal_ids
    undercond = "_" if len(animal_ids) > 1 else ""

    # Experiment name: everything before "DLC" in the track ID, if present
    try:
        vid_name = re.findall("(.*)DLC", tracks[vid_index])[0]
    except IndexError:
        vid_name = tracks[vid_index]

    coords = coords[vid_name]
    dists = dists[vid_name]
    speeds = speeds[vid_name]

    arena_abs = coordinates.get_arenas[1][0]
    arena, h, w = deepof.utils.recognize_arena(
        videos, vid_index, path, recog_limit, coordinates._arena
    )

    # Dictionary with motives per frame
    tag_dict = {}

    def onebyone_contact(bparts: List, ids: List = None):
        """Returns a smooth boolean array with 1to1 contacts between two mice.

        bparts[0] belongs to ids[0] and bparts[-1] to ids[1]; ids defaults to animal_ids."""
        ids = animal_ids if ids is None else ids
        return deepof.utils.smooth_boolean_array(
            close_single_contact(
                coords,
                ids[0] + bparts[0],
                ids[1] + bparts[-1],
                hparams["close_contact_tol"],
                arena_abs,
                arena[2],
            )
        )

    def twobytwo_contact(rev):
        """Returns a smooth boolean array with side by side contacts between two mice"""
        return deepof.utils.smooth_boolean_array(
            close_double_contact(
                coords,
                animal_ids[0] + "_Nose",
                animal_ids[0] + "_Tail_base",
                animal_ids[1] + "_Nose",
                animal_ids[1] + "_Tail_base",
                hparams["side_contact_tol"],
                rev=rev,
                arena_abs=arena_abs,
                arena_rel=arena[2],
            )
        )

    if len(animal_ids) == 2:
        # Define behaviours that can be computed on the fly from the distance matrix
        tag_dict["nose2nose"] = onebyone_contact(bparts=["_Nose"])

        tag_dict["sidebyside"] = twobytwo_contact(rev=False)

        tag_dict["sidereside"] = twobytwo_contact(rev=True)

        for _id in animal_ids:
            other = [i for i in animal_ids if i != _id][0]
            # _id's nose touching the OTHER animal's tail base (one column per animal)
            tag_dict[_id + "_nose2tail"] = onebyone_contact(
                bparts=["_Nose", "_Tail_base"], ids=[_id, other]
            )

        for _id in animal_ids:
            tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
                following_path(
                    dists,
                    coords,
                    follower=_id,
                    followed=[i for i in animal_ids if i != _id][0],
                    frames=hparams["follow_frames"],
                    tol=hparams["follow_tol"],
                )
            )

    for _id in animal_ids:
        tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
            climb_wall(arena_type, arena, coords, w / 100, _id + undercond + "Nose")
        )
        tag_dict[_id + undercond + "speed"] = speeds[_id + undercond + "Center"]
        tag_dict[_id + undercond + "huddle"] = deepof.utils.smooth_boolean_array(
            huddle(
                coords,
                speeds,
                hparams["huddle_forward"],
                hparams["huddle_spine"],
                hparams["huddle_speed"],
                animal_id=_id,
            )
        )

    tag_df = pd.DataFrame(tag_dict)

    return tag_df


lucas_miranda's avatar
lucas_miranda committed
568
569
570
571
572
573
574
575
def tag_rulebased_frames(
    frame,
    font,
    frame_speeds,
    animal_ids,
    corners,
    tag_dict,
    fnum,
    dims,
    undercond,
    hparams,
):
    """Helper function for rule_based_video. Annotates a given frame, in place, with
    on-screen information about the recognised patterns at frame number fnum.

    Parameters:
        - frame: video frame to draw on (modified by cv2.putText)
        - font: cv2 font identifier used for every label
        - frame_speeds: dict of speeds keyed by animal id when several animals are
        tracked, or a single number for one animal
        - animal_ids (list): identifiers of the tracked animals
        - corners (dict): label anchor positions, as returned by frame_corners
        - tag_dict (pandas.DataFrame): rule-based tags, as returned by rule_based_tagging
        - fnum (int): index of the frame being annotated
        - dims (tuple): (width, height) of the frame in pixels
        - undercond (str): separator between animal id and tag name
        - hparams (dict): hyperparameters; only "huddle_speed" is read here

    Returns:
        None"""

    width, height = dims
    white = (255, 255, 255)

    def draw(label, anchor, colour=white):
        """Shorthand for cv2.putText with the shared font, scale 1 and thickness 2"""
        return cv2.putText(frame, label, anchor, font, 1, colour, 2)

    def first_is_faster():
        return frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]

    def faster_corner():
        """Bottom-left corner if the first animal is faster, bottom-right otherwise"""
        return corners["downleft"] if first_is_faster() else corners["downright"]

    def pick_colour(cond=None):
        """(150, 150, 255) when cond holds (default: first animal is faster), else (150, 255, 150)"""
        flag = first_is_faster() if cond is None else cond
        return (150, 150, 255) if flag else (150, 255, 150)

    anchors = list(
        zip(
            animal_ids,
            (corners["downleft"], corners["downright"]),
            (corners["upleft"], corners["upright"]),
        )
    )

    if len(animal_ids) > 1:
        if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
            draw("Nose-Nose", faster_corner())

        # Each animal's nose-to-tail contact is drawn on its own side of the frame
        for _id, side in ((animal_ids[0], "downleft"), (animal_ids[1], "downright")):
            if tag_dict[_id + "_nose2tail"][fnum] and not tag_dict["sidereside"][fnum]:
                draw("Nose-Tail", corners[side])

        for tag, label in (("sidebyside", "Side-side"), ("sidereside", "Side-Rside")):
            if tag_dict[tag][fnum]:
                draw(label, faster_corner())

        follow_anchor = (int(width * 0.3 / 10), int(height / 10))
        for _id, _, _ in anchors:
            if (
                tag_dict[_id + "_following"][fnum]
                and not tag_dict[_id + "_climbing"][fnum]
            ):
                draw("*f", follow_anchor, pick_colour())

    for _id, down_pos, up_pos in anchors:

        is_climbing = tag_dict[_id + undercond + "climbing"][fnum]
        if is_climbing:
            draw("Climbing", down_pos)
        if tag_dict[_id + undercond + "huddle"][fnum] and not is_climbing:
            draw("huddle", down_pos)

        # Speed label: highlight the fastest animal, or a single slow animal
        if len(animal_ids) > 1:
            shown_speed = frame_speeds[_id]
            highlight = shown_speed == max(frame_speeds.values())
        else:
            shown_speed = frame_speeds
            highlight = hparams["huddle_speed"] > frame_speeds

        draw(
            str(np.round(shown_speed, 2)) + " mmpf",
            up_pos,
            pick_colour(cond=highlight),
        )


lucas_miranda's avatar
lucas_miranda committed
676
# noinspection PyProtectedMember,PyDefaultArgument
def rule_based_video(
    coordinates: Coordinates,
    tracks: List,
    videos: List,
    vid_index: int,
    tag_dict: pd.DataFrame,
    frame_limit: int = np.inf,
    recog_limit: int = 1,
    path: str = os.path.join("."),
    hparams: dict = {},
) -> bool:
    """Renders a version of the input video with all rule-based taggings in place.

    The output is written to "<video name>_tagged.avi" (MJPG) in the current working
    directory, at hparams["fps"] frames per second.

    Parameters:
        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
        - tracks (list): list containing experiment IDs as strings
        - videos (list): list of videos to load, in the same order as tracks
        - vid_index (int): index in videos of the experiment to annotate
        - tag_dict (pandas.DataFrame): rule-based tags, as returned by rule_based_tagging
        - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
        - recog_limit (int): number of frames to use for arena recognition (1 by default)
        - path (str): directory in which the experimental data is stored
        - hparams (dict): dictionary to overwrite the default values of the hyperparameters of the functions
        that the rule-based pose estimation utilizes. Values can be:
            - speed_pause (int): number of frames during which the displayed speed is kept fixed
            - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - follow_frames (int): number of frames during which the following trait is tracked
            - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
            in order to report a detection
            - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
            - huddle_spine (int): maximum average distance between spine body parts to report a huddle detection
            - huddle_speed (int): maximum speed to report a huddle detection
            - fps (int): frames per second of the output video

    Returns:
        True

    """

    # DATA OBTENTION AND PREPARATION
    hparams = get_hparameters(hparams)
    animal_ids = coordinates._animal_ids
    multi_animal = len(animal_ids) > 1
    undercond = "_" if multi_animal else ""

    # Experiment name: everything before "DLC" in the track ID, if present
    try:
        vid_name = re.findall("(.*)DLC", tracks[vid_index])[0]
    except IndexError:
        vid_name = tracks[vid_index]

    arena, h, w = deepof.utils.recognize_arena(
        videos, vid_index, path, recog_limit, coordinates._arena
    )
    # frame_corners takes (width, height), matching the (w, h) dims used for drawing below
    corners = frame_corners(w, h)

    font = cv2.FONT_HERSHEY_COMPLEX_SMALL
    # -inf marks "not yet read", forcing a speed lookup on the first frame
    frame_speeds = {_id: -np.inf for _id in animal_ids} if multi_animal else -np.inf

    cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
    # Keep track of the frame number, to align with the tracking data
    fnum = 0
    writer = None

    try:
        # Loop over the frames in the video
        while cap.isOpened() and fnum < frame_limit:

            ret, frame = cap.read()
            # if frame is read correctly ret is True
            if not ret:  # pragma: no cover
                print("Can't receive frame (stream end?). Exiting ...")
                break

            # Capture speeds; refreshed only every speed_pause frames so the
            # on-screen value stays readable
            if multi_animal:
                if (
                    list(frame_speeds.values())[0] == -np.inf
                    or fnum % hparams["speed_pause"] == 0
                ):
                    for _id in animal_ids:
                        frame_speeds[_id] = tag_dict[_id + undercond + "speed"][fnum]
            elif frame_speeds == -np.inf or fnum % hparams["speed_pause"] == 0:
                frame_speeds = tag_dict["speed"][fnum]

            # Display all annotations in the output video
            tag_rulebased_frames(
                frame,
                font,
                frame_speeds,
                animal_ids,
                corners,
                tag_dict,
                fnum,
                (w, h),
                undercond,
                hparams,
            )

            if writer is None:
                # Open the writer lazily, once the frame size is known
                writer = cv2.VideoWriter()
                writer.open(
                    vid_name + "_tagged.avi",
                    cv2.VideoWriter_fourcc(*"MJPG"),
                    hparams["fps"],
                    (frame.shape[1], frame.shape[0]),
                    True,
                )

            writer.write(frame)
            fnum += 1

    finally:
        # Release both handles even on error so the output file is finalized
        cap.release()
        if writer is not None:
            writer.release()

    return True