pose_utils.py 36.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for rule-based pose estimation. See documentation for details

"""

11
12
13
from itertools import combinations
from scipy import stats
from typing import Any, List, NewType
14
import cv2
lucas_miranda's avatar
lucas_miranda committed
15
import deepof.utils
16
17
18
19
20
21
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import regex as re
import seaborn as sns
lucas_miranda's avatar
lucas_miranda committed
22
import warnings
23

lucas_miranda's avatar
lucas_miranda committed
24
25
26
27
# Ignore warning with no downstream effect
warnings.filterwarnings("ignore", message="All-NaN slice encountered")

# Create custom string type
28
29
30
31
Coordinates = NewType("Coordinates", Any)


def close_single_contact(
    pos_dframe: pd.DataFrame,
    left: str,
    right: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
) -> np.array:
    """Returns a boolean array that's True if the specified body parts are closer than tol.

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left (string): First member of the potential contact
        - right (string or list of strings): Second member (or candidate members)
        of the potential contact
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video

    Returns:
        - contact_array (np.array): True if the distance between the two specified points
        is less than tol, False otherwise"""

    def _in_contact(right_bpart):
        """Per-frame distance between left and right_bpart, rescaled from pixels
        to mm via arena_abs / arena_rel, thresholded at tol"""
        return (
            np.linalg.norm(pos_dframe[left] - pos_dframe[right_bpart], axis=1)
            * arena_abs
        ) / arena_rel < tol

    # isinstance is preferred over type() comparisons, and also covers subclasses
    if isinstance(right, str):
        close_contact = _in_contact(right)

    elif isinstance(right, list):
        # A contact is reported if left is close enough to ANY of the candidates
        close_contact = np.any([_in_contact(r) for r in right], axis=0)

    else:
        # Fail loudly instead of silently returning None
        raise TypeError("right should be a string or a list of strings")

    return close_contact


def close_double_contact(
    pos_dframe: pd.DataFrame,
    left1: str,
    left2: str,
    right1: str,
    right2: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
    rev: bool = False,
) -> np.array:
    """Returns a boolean array that's True if the specified body parts are closer than tol.

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left1 (string): First contact point of animal 1
        - left2 (string): Second contact point of animal 1
        - right1 (string): First contact point of animal 2
        - right2 (string): Second contact point of animal 2
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video
        - rev (bool): reverses the default behaviour (nose2tail contact for both mice)

    Returns:
        - double_contact (np.array): True if the distance between the two specified points
        is less than tol, False otherwise"""

    def _in_contact(bpart_a, bpart_b):
        """Per-frame distance between two body parts, rescaled from pixels
        to mm via arena_abs / arena_rel, thresholded at tol"""
        return (
            np.linalg.norm(pos_dframe[bpart_a] - pos_dframe[bpart_b], axis=1)
            * arena_abs
        ) / arena_rel < tol

    # The previous implementation repeated the expression above four times;
    # both branches require BOTH pairings to be in contact simultaneously
    if rev:
        double_contact = _in_contact(right1, left2) & _in_contact(right2, left1)

    else:
        double_contact = _in_contact(right1, left1) & _in_contact(right2, left2)

    return double_contact


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def rotate(origin, point, ang):
    """Auxiliar function to climb_wall and sniff_object. Rotates x,y coordinates over a pivot"""

    pivot_x, pivot_y = origin
    cos_a, sin_a = np.cos(ang), np.sin(ang)

    # Offsets of the point relative to the pivot
    dx = point[0] - pivot_x
    dy = point[1] - pivot_y

    # Standard 2D rotation matrix applied around the pivot
    return (
        pivot_x + cos_a * dx - sin_a * dy,
        pivot_y + sin_a * dx + cos_a * dy,
    )


def outside_ellipse(x, y, e_center, e_axes, e_angle, threshold=0.0):
    """Auxiliar function to climb_wall and sniff_object. Returns True if the passed x, y coordinates
    are outside the ellipse denoted by e_center, e_axes and e_angle, with a certain threshold"""

    # Rotate the point into the ellipse's axis-aligned frame before testing
    rot_x, rot_y = rotate(e_center, (x, y), np.radians(e_angle))

    # Semi-axes are grown (or shrunk, for negative values) by the threshold
    a_sq = (e_axes[0] + threshold) ** 2
    b_sq = (e_axes[1] + threshold) ** 2

    # Canonical ellipse equation: > 1 means the point lies outside
    return (rot_x - e_center[0]) ** 2 / a_sq + (rot_y - e_center[1]) ** 2 / b_sq > 1


151
def climb_wall(
    arena_type: str,
    arena: np.array,
    pos_dict: pd.DataFrame,
    tol: float,
    nose: str,
    centered_data: bool = False,
) -> np.array:
    """Returns True if the specified mouse is climbing the wall

    Parameters:
        - arena_type (str): arena type; must be one of ['circular']
        - arena (np.array): contains arena location and shape details
        - pos_dict (table_dict): position over time for all videos in a project
        - tol (float): minimum tolerance to report a hit
        - nose (str): indicates the name of the body part representing the nose of
        the selected animal
        - centered_data (bool): indicates whether the input data is centered

    Returns:
        - climbing (np.array): boolean array. True if selected animal
        is climbing the walls of the arena"""

    nose_position = pos_dict[nose]

    if arena_type != "circular":
        raise NotImplementedError("Supported values for arena_type are ['circular']")

    # If the data was centered during preprocessing, the arena center is the origin
    ellipse_center = np.zeros(2) if centered_data else np.array(arena[0])

    # Climbing is reported when the nose leaves the arena ellipse grown by tol
    climbing = outside_ellipse(
        x=nose_position["x"],
        y=nose_position["y"],
        e_center=ellipse_center,
        e_axes=arena[1],
        e_angle=-arena[2],
        threshold=tol,
    )

    return climbing


196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def sniff_object(
    speed_dframe: pd.DataFrame,
    arena_type: str,
    arena: np.array,
    pos_dict: pd.DataFrame,
    tol: float,
    tol_speed: float,
    nose: str,
    centered_data: bool = False,
    object: str = "arena",
    animal_id: str = "",
):
    """Returns True if the specified mouse is sniffing an object

    Parameters:
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - arena_type (str): arena type; must be one of ['circular']
        - arena (np.array): contains arena location and shape details
        - pos_dict (table_dict): position over time for all videos in a project
        - tol (float): minimum tolerance to report a hit
        - tol_speed (float): maximum centroid speed at which sniffing is reported
        - nose (str): indicates the name of the body part representing the nose of
        the selected animal
        - centered_data (bool): indicates whether the input data is centered
        - object (str): indicates the object that the animal is sniffing.
        Can be one of ['arena', 'partner']. NOTE: shadows the builtin `object`,
        but renaming would break callers passing it as a keyword.
        - animal_id (str): identifier of the animal to tag; empty for single-animal data

    Returns:
        - sniffing (np.array): boolean array. True if selected animal
        is sniffing the selected object"""

    nose_position = pos_dict[nose]
    nosing = True

    if animal_id != "":
        animal_id += "_"

    if object == "arena":
        if arena_type == "circular":
            ellipse_center = np.zeros(2) if centered_data else np.array(arena[0])

            def _nose_outside(thresh):
                """Tests the nose against the arena ellipse grown by thresh"""
                return outside_ellipse(
                    x=nose_position["x"],
                    y=nose_position["y"],
                    e_center=ellipse_center,
                    e_axes=arena[1],
                    e_angle=-arena[2],
                    threshold=thresh,
                )

            # Nosing the wall: outside the shrunk ellipse, but inside the grown one
            nosing = _nose_outside(-tol) & (~_nose_outside(tol))

    elif object == "partner":
        raise NotImplementedError

    else:
        raise ValueError("object should be one of [arena, partner]")

    # Sniffing also requires the animal's centroid to be roughly stationary
    still = speed_dframe[animal_id + "Center"] < tol_speed

    return nosing & still


268
def huddle(
    pos_dframe: pd.DataFrame,
    speed_dframe: pd.DataFrame,
    tol_forward: float,
    tol_speed: float,
    animal_id: str = "",
) -> np.array:
    """Returns true when the mouse is huddling using simple rules.

    Parameters:
        - pos_dframe (pandas.DataFrame): position of body parts over time
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - tol_forward (float): Maximum tolerated distance between back and
        front hips on either side
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - animal_id (str): identifier of the animal to tag; empty for single-animal data

    Returns:
        hudd (np.array): True if the animal is huddling, False otherwise
    """

    if animal_id != "":
        animal_id += "_"

    def _hips_close(side):
        """True per frame where the distance between the given side's back
        and front hips is below tol_forward (body curled up)"""
        return (
            np.linalg.norm(
                pos_dframe[animal_id + side + "_bhip"]
                - pos_dframe[animal_id + side + "_fhip"],
                axis=1,
            )
            < tol_forward
        )

    # Huddling requires a tucked posture on at least one side...
    tucked = _hips_close("Left") | _hips_close("Right")

    # ...and a roughly stationary centroid
    still = speed_dframe[animal_id + "Center"] < tol_speed

    return tucked & still


313
314
315
316
317
318
319
def dig(
    speed_dframe: pd.DataFrame,
    likelihood_dframe: pd.DataFrame,
    tol_speed: float,
    tol_likelihood: float,
    animal_id: str = "",
):
    """Returns true when the mouse is digging using simple rules.

    Parameters:
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - likelihood_dframe (pandas.DataFrame): likelihood of body part tracker over time,
        as directly obtained from DeepLabCut
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - tol_likelihood (float): Maximum tolerated likelihood for the nose (if the animal
        is digging, the nose is momentarily occluded).
        - animal_id (str): identifier of the animal to tag; empty for single-animal data

    Returns:
        dig (np.array): True if the animal is digging, False otherwise
    """

    prefix = animal_id + "_" if animal_id != "" else ""

    # Digging = slow centroid movement combined with an occluded nose
    # (occlusion shows up as a low tracking likelihood)
    center_still = speed_dframe[prefix + "Center"] < tol_speed
    nose_hidden = likelihood_dframe[prefix + "Nose"] < tol_likelihood

    return center_still & nose_hidden
342
343


344
def look_around(
    speed_dframe: pd.DataFrame,
    likelihood_dframe: pd.DataFrame,
    tol_speed: float,
    tol_likelihood: float,
    animal_id: str = "",
):
    """Returns true when the mouse is standing still and looking around, using simple rules.

    Parameters:
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - likelihood_dframe (pandas.DataFrame): likelihood of body part tracker over time,
        as directly obtained from DeepLabCut
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - tol_likelihood (float): Minimum tolerated likelihood for the nose (the nose
        must be visible and reliably tracked for the tag to be reported).
        - animal_id (str): identifier of the animal to tag; empty for single-animal data

    Returns:
        lookaround (np.array): True if the animal is standing still and looking around, False otherwise
    """

    if animal_id != "":
        animal_id += "_"

    # The body stays roughly still...
    speed = speed_dframe[animal_id + "Center"] < tol_speed
    # ...while the nose moves faster than the centroid (head scanning around)...
    nose_speed = speed_dframe[animal_id + "Center"] < speed_dframe[animal_id + "Nose"]
    # ...and the nose is reliably tracked (not occluded)
    nose_likelihood = likelihood_dframe[animal_id + "Nose"] > tol_likelihood

    lookaround = speed & nose_likelihood & nose_speed

    return lookaround
375
376


377
def following_path(
    distance_dframe: pd.DataFrame,
    position_dframe: pd.DataFrame,
    follower: str,
    followed: str,
    frames: int = 20,
    tol: float = 0,
) -> np.array:
    """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
    followed has walked over the last specified number of frames

        Parameters:
            - distance_dframe (pandas.DataFrame): distances between bodyparts; generated by the preprocess module
            - position_dframe (pandas.DataFrame): position of bodyparts; generated by the preprocess module
            - follower (str) identifier for the animal who's following
            - followed (str) identifier for the animal who's followed
            - frames (int) frames in which to track whether the process consistently occurs,
            - tol (float) Maximum distance for which True is returned

        Returns:
            - follow (np.array): boolean sequence, True if conditions are fulfilled, False otherwise"""

    follower_nose = follower + "_Nose"
    followed_tail = followed + "_Tail_base"

    # Distance from the follower's nose to each of the followed animal's
    # tail-base positions over the last `frames` frames
    path_dists = pd.DataFrame(
        {
            lag: np.linalg.norm(
                position_dframe[follower_nose]
                - position_dframe[followed_tail].shift(lag),
                axis=1,
            )
            for lag in range(frames)
        }
    )
    close_to_path = path_dists.min(axis=1) < tol

    def _dist(bpart_a, bpart_b):
        """Looks up a precomputed distance column; keys are sorted body-part tuples"""
        return distance_dframe[tuple(sorted([bpart_a, bpart_b]))]

    # Orientation checks: the follower's nose must be closer to the followed's
    # tail base than both the follower's own tail base and the followed's nose
    right_orient1 = _dist(follower_nose, followed_tail) < _dist(
        follower + "_Tail_base", followed_tail
    )
    right_orient2 = _dist(follower_nose, followed_tail) < _dist(
        follower_nose, followed + "_Nose"
    )

    return np.all(np.array([close_to_path, right_orient1, right_orient2]), axis=0)


def single_behaviour_analysis(
    behaviour_name: str,
    treatment_dict: dict,
    behavioural_dict: dict,
    plot: int = 0,
    stat_tests: bool = True,
    save: str = None,
    ylim: float = None,
) -> list:
    """Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
    with the actual tags, outputs a box plot and a series of significance tests amongst the groups

     Parameters:
         - behaviour_name (str): name of the behavioural trait to analize
         - treatment_dict (dict): dictionary containing video names as keys and experimental conditions as values
         - behavioural_dict (dict): tagged dictionary containing video names as keys and annotations as values
         - plot (int): Silent if 0; otherwise, indicates the dpi of the figure to plot
         - stat_tests (bool): performs FDR corrected Mann-U non-parametric tests among all groups if True
         - save (str): Saves the produced figure to the specified file
         - ylim (float): y-limit for the boxplot. Ignored if plot == False

     Returns:
         - beh_dict (dict): dictionary containing experimental conditions as keys and video names as values
         - stat_dict (dict): dictionary containing condition pairs as keys and stat results as values"""

    beh_dict = {condition: [] for condition in treatment_dict.keys()}

    for condition in beh_dict.keys():
        for ind in treatment_dict[condition]:
            # Proportion of frames in which the behaviour was tagged for this video.
            # NOTE: list.append is required here — the previous `+=` raised a
            # TypeError, since a scalar proportion is not iterable
            beh_dict[condition].append(
                np.sum(behavioural_dict[ind][behaviour_name])
                / len(behavioural_dict[ind][behaviour_name])
            )

    return_list = [beh_dict]

    if plot > 0:

        fig, ax = plt.subplots(dpi=plot)

        sns.boxplot(
            x=list(beh_dict.keys()),
            y=list(beh_dict.values()),
            orient="vertical",
            ax=ax,
        )

        ax.set_title("{} across groups".format(behaviour_name))
        ax.set_ylabel("Proportion of frames")

        if ylim is not None:
            ax.set_ylim(ylim)

        if save is not None:  # pragma: no cover
            plt.savefig(save)

        return_list.append(fig)

    if stat_tests:
        stat_dict = {}
        for i in combinations(treatment_dict.keys(), 2):
            # Solves issue with automatically generated examples:
            # identical or zero-variance groups cannot be tested
            if np.any(
                np.array(
                    [
                        beh_dict[i[0]] == beh_dict[i[1]],
                        np.var(beh_dict[i[0]]) == 0,
                        np.var(beh_dict[i[1]]) == 0,
                    ]
                )
            ):
                stat_dict[i] = "Identical sources. Couldn't run"
            else:
                stat_dict[i] = stats.mannwhitneyu(
                    beh_dict[i[0]], beh_dict[i[1]], alternative="two-sided"
                )
        return_list.append(stat_dict)

    return return_list


def max_behaviour(
    behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.array:
    """Returns the most frequent behaviour in a window of window_size frames

    Parameters:
            - behaviour_dframe (pd.DataFrame): boolean matrix containing occurrence
            of tagged behaviours per frame in the video
            - window_size (int): size of the window to use when computing
            the maximum behaviour per time slot
            - stepped (bool): sliding windows don't overlap if True. False by default

    Returns:
        - max_array (np.array): string array with the most common behaviour per instance
        of the sliding window"""

    # Speed columns are continuous and would dominate idxmax; exclude them
    speed_cols = [col for col in behaviour_dframe.columns if "speed" in col.lower()]

    # Count how often each behaviour was tagged within each centered window
    tag_counts = (
        behaviour_dframe.drop(speed_cols, axis=1)
        .astype("float")
        .rolling(window_size, center=True)
        .sum()
    )

    if stepped:
        # Keep every window_size-th window only, so windows don't overlap
        tag_counts = tag_counts[::window_size]

    # The first window is discarded; each remaining row maps to its dominant tag
    return np.array(tag_counts[1:].idxmax(axis=1))


540
def get_hparameters(hparams: dict = None) -> dict:
    """Returns the default hyperparameters of the rule-based annotation pipeline,
    overwritten with any values passed in hparams.

    Parameters:
        - hparams (dict): dictionary containing hyperparameters to overwrite

    Returns:
        - defaults (dict): dictionary with overwritten parameters. Those not
        specified in the input retain their default values"""

    defaults = {
        "speed_pause": 5,
        "climb_tol": 10,
        "close_contact_tol": 35,
        "side_contact_tol": 80,
        "follow_frames": 10,
        "follow_tol": 5,
        "huddle_forward": 15,
        "huddle_speed": 1,
        "nose_likelihood": 0.85,
        "fps": 24,
    }

    # None (rather than {}) as the default avoids the mutable default
    # argument pitfall
    defaults.update(hparams or {})

    return defaults


570
571
572
573
def frame_corners(w, h, corners: dict = None):
    """Returns a dictionary with the corner positions of the video frame

    Parameters:
        - w (int): width of the frame in pixels
        - h (int): height of the frame in pixels
        - corners (dict): dictionary containing corners to overwrite

    Returns:
        - defaults (dict): dictionary with overwritten parameters. Those not
        specified in the input retain their default values"""

    defaults = {
        "downleft": (int(w * 0.3 / 10), int(h / 1.05)),
        "downright": (int(w * 6.5 / 10), int(h / 1.05)),
        "upleft": (int(w * 0.3 / 10), int(h / 20)),
        "upright": (int(w * 6.3 / 10), int(h / 20)),
    }

    # None (rather than {}) as the default avoids the mutable default
    # argument pitfall
    defaults.update(corners or {})

    return defaults


596
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
    tracks: List,
    videos: List,
    coordinates: Coordinates,
    coords: Any,
    dists: Any,
    speeds: Any,
    vid_index: int,
    arena_type: str,
    recog_limit: int = 1,
    path: str = os.path.join("."),
    params: dict = {},
) -> pd.DataFrame:
    """Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
    video displaying the information in real time

    Parameters:
        - tracks (list): list containing experiment IDs as strings
        - videos (list): list of videos to load, in the same order as tracks
        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
        - coords (deepof.preprocessing.table_dict): table_dict with already processed coordinates
        - dists (deepof.preprocessing.table_dict): table_dict with already processed distances
        - speeds (deepof.preprocessing.table_dict): table_dict with already processed speeds
        - vid_index (int): index in videos of the experiment to annotate
        - arena_type (str): arena type; must be one of ['circular']
        - path (str): directory in which the experimental data is stored
        - recog_limit (int): number of frames to use for arena recognition (1 by default)
        - params (dict): dictionary to overwrite the default values of the parameters of the functions
        that the rule-based pose estimation utilizes. See documentation for details.

    Returns:
        - tag_df (pandas.DataFrame): table with traits as columns and frames as rows. Each
        value is a boolean indicating trait detection at a given time"""

    # Merge user-provided parameters over the pipeline defaults
    params = get_hparameters(params)
    animal_ids = coordinates._animal_ids
    # Multi-animal experiments prefix body parts with "<id>_"
    undercond = "_" if len(animal_ids) > 1 else ""

    # Experiment ID: strip the DeepLabCut model suffix from the track name, if present
    try:
        vid_name = re.findall("(.*)DLC", tracks[vid_index])[0]
    except IndexError:
        vid_name = tracks[vid_index]

    # Select the tables belonging to the current experiment
    coords = coords[vid_name]
    dists = dists[vid_name]
    speeds = speeds[vid_name]
    likelihoods = coordinates.get_quality()[vid_name]
    # NOTE(review): get_arenas is accessed without parentheses — presumably a
    # property returning (arenas, dims, ...); confirm against the coordinates class
    arena_abs = coordinates.get_arenas[1][0]
    arena, h, w = deepof.utils.recognize_arena(
        videos, vid_index, path, recog_limit, coordinates._arena
    )

    # Dictionary with motives per frame
    tag_dict = {}

    # Bulk body parts
    main_body = [
        "Left_ear",
        "Right_ear",
        "Spine_1",
        "Center",
        "Spine_2",
        "Left_fhip",
        "Right_fhip",
        "Left_bhip",
        "Right_bhip",
    ]

    def onebyone_contact(bparts: List):
        """Returns a smooth boolean array with 1to1 contacts between two mice"""
        nonlocal coords, animal_ids, params, arena_abs, arena

        # bparts[-1] is either a single suffix (str) or a list of suffixes;
        # string concatenation with a list raises TypeError, which is caught here
        try:
            right = animal_ids[1] + bparts[-1]
        except TypeError:
            right = [animal_ids[1] + "_" + suffix for suffix in bparts[-1]]

        return deepof.utils.smooth_boolean_array(
            close_single_contact(
                coords,
                animal_ids[0] + bparts[0],
                right,
                params["close_contact_tol"],
                arena_abs,
                arena[1][1],
            )
        )

    def twobytwo_contact(rev):
        """Returns a smooth boolean array with side by side contacts between two mice"""

        nonlocal coords, animal_ids, params, arena_abs, arena
        return deepof.utils.smooth_boolean_array(
            close_double_contact(
                coords,
                animal_ids[0] + "_Nose",
                animal_ids[0] + "_Tail_base",
                animal_ids[1] + "_Nose",
                animal_ids[1] + "_Tail_base",
                params["side_contact_tol"],
                rev=rev,
                arena_abs=arena_abs,
                arena_rel=arena[1][1],
            )
        )

    def overall_speed(ovr_speeds, _id, ucond):
        """Returns the per-frame median speed across the animal's main body parts.
        The first frame is set to NaN, as speeds are undefined there."""
        bparts = [
            "Center",
            "Spine_1",
            "Spine_2",
            "Nose",
            "Left_ear",
            "Right_ear",
            "Left_fhip",
            "Right_fhip",
            "Left_bhip",
            "Right_bhip",
            "Tail_base",
        ]
        array = ovr_speeds[[_id + ucond + bpart for bpart in bparts]]
        avg_speed = np.nanmedian(array[1:], axis=1)
        return np.insert(avg_speed, 0, np.nan, axis=0)

    if len(animal_ids) == 2:
        # Define behaviours that can be computed on the fly from the distance matrix
        tag_dict["nose2nose"] = onebyone_contact(bparts=["_Nose"])

        tag_dict["sidebyside"] = twobytwo_contact(rev=False)

        tag_dict["sidereside"] = twobytwo_contact(rev=True)

        # Nose-to-tail contact, tagged once per animal in each direction
        tag_dict[animal_ids[0] + "_nose2tail"] = onebyone_contact(
            bparts=["_Nose", "_Tail_base"]
        )
        tag_dict[animal_ids[1] + "_nose2tail"] = onebyone_contact(
            bparts=["_Tail_base", "_Nose"]
        )
        # Nose-to-body contact: nose against any of the partner's bulk body parts
        tag_dict[animal_ids[0] + "_nose2body"] = onebyone_contact(
            bparts=[
                "_Nose",
                main_body,
            ]
        )
        tag_dict[animal_ids[1] + "_nose2body"] = onebyone_contact(
            bparts=[
                "_Nose",
                main_body,
            ]
        )

        # Each animal is checked for following the other one's path
        for _id in animal_ids:
            tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
                following_path(
                    dists,
                    coords,
                    follower=_id,
                    followed=[i for i in animal_ids if i != _id][0],
                    frames=params["follow_frames"],
                    tol=params["follow_tol"],
                )
            )

    # Single-animal behaviours, tagged independently for each animal
    for _id in animal_ids:
        tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
            climb_wall(
                arena_type,
                arena,
                coords,
                params["climb_tol"],
                _id + undercond + "Nose",
            )
        )
        tag_dict[_id + undercond + "sniffing"] = deepof.utils.smooth_boolean_array(
            sniff_object(
                speeds,
                arena_type,
                arena,
                coords,
                params["climb_tol"],
                params["huddle_speed"],
                _id + undercond + "Nose",
                object="arena",
                animal_id=_id,
            )
        )
        tag_dict[_id + undercond + "speed"] = overall_speed(speeds, _id, undercond)
        tag_dict[_id + undercond + "huddle"] = deepof.utils.smooth_boolean_array(
            huddle(
                coords,
                speeds,
                params["huddle_forward"],
                params["huddle_speed"],
                animal_id=_id,
            )
        )
        tag_dict[_id + undercond + "dig"] = deepof.utils.smooth_boolean_array(
            dig(
                speeds,
                likelihoods,
                params["huddle_speed"],
                params["nose_likelihood"],
                animal_id=_id,
            )
        )
        tag_dict[_id + undercond + "lookaround"] = deepof.utils.smooth_boolean_array(
            look_around(
                speeds,
                likelihoods,
                params["huddle_speed"],
                params["nose_likelihood"],
                animal_id=_id,
            )
        )

    tag_df = pd.DataFrame(tag_dict)

    return tag_df


lucas_miranda's avatar
lucas_miranda committed
816
817
818
819
820
821
822
823
824
825
def tag_rulebased_frames(
    frame,
    font,
    frame_speeds,
    animal_ids,
    corners,
    tag_dict,
    fnum,
    undercond,
    hparams,
    arena,
    debug,
    coords,
):
    """Helper function for rule_based_video. Annotates a given frame with on-screen information
    about the recognised patterns.

    Parameters:
        - frame (np.ndarray): video frame to annotate in place
        - font: cv2 font constant used for all on-screen text
        - frame_speeds: per-animal speed cache; a dict keyed by animal id when there is
        more than one animal, a single scalar otherwise
        - animal_ids (list): list of animal identifiers (prefixes of the tag_dict columns)
        - corners (dict): pixel positions for the four text anchors
        ("downleft", "downright", "upleft", "upright")
        - tag_dict (pd.DataFrame): rule-based annotations, one boolean column per tag,
        indexed by frame number
        - fnum (int): index of the current frame in tag_dict
        - undercond (str): "_" separator used in multi-animal column names, "" otherwise
        - hparams (dict): hyperparameters; only "huddle_speed" is read here
        - arena (tuple): (ellipse parameters, frame height, frame width)
        - debug (bool): if True, the detected arena and tracked body parts are drawn
        - coords (pd.DataFrame): body part coordinates for the current experiment"""

    arena, h, w = arena

    def write_on_frame(text, pos, col=(255, 255, 255)):
        """Partial closure over cv2.putText to avoid code repetition"""
        return cv2.putText(frame, text, pos, font, 0.75, col, 2)

    def conditional_flag():
        """Returns a tag depending on a condition"""
        if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]:
            return left_flag
        else:
            return right_flag

    def conditional_pos():
        """Returns a position depending on a condition"""
        if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]:
            return corners["downleft"]
        else:
            return corners["downright"]

    def conditional_col(cond=None):
        """Returns a colour depending on a condition"""
        if cond is None:
            cond = frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
        if cond:
            return 150, 255, 150
        else:
            return 150, 150, 255

    # Keep track of space usage in the output video
    # The flags are set to False as soon as the lower
    # corners are occupied with text
    left_flag, right_flag = True, True

    if debug:
        # Print arena for debugging
        cv2.ellipse(frame, arena[0], arena[1], arena[2], 0, 360, (0, 255, 0), 3)
        # Print body parts for debugging
        for bpart in coords.columns.levels[0]:
            if not np.isnan(coords[bpart]["x"][fnum]):
                cv2.circle(
                    frame,
                    (int(coords[bpart]["x"][fnum]), int(coords[bpart]["y"][fnum])),
                    radius=3,
                    color=(
                        (255, 0, 0) if bpart.startswith(animal_ids[0]) else (0, 0, 255)
                    ),
                    thickness=-1,
                )
        # Print frame number
        write_on_frame("Frame " + str(fnum), (int(w * 0.3 / 10), int(h / 1.15)))

    if len(animal_ids) > 1:

        if tag_dict["nose2nose"][fnum]:
            write_on_frame("Nose-Nose", conditional_pos())
            if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]:
                left_flag = False
            else:
                right_flag = False

        if tag_dict[animal_ids[0] + "_nose2body"][fnum] and left_flag:
            write_on_frame("nose2body", corners["downleft"])
            left_flag = False

        if tag_dict[animal_ids[1] + "_nose2body"][fnum] and right_flag:
            write_on_frame("nose2body", corners["downright"])
            right_flag = False

        if tag_dict[animal_ids[0] + "_nose2tail"][fnum] and left_flag:
            write_on_frame("Nose-Tail", corners["downleft"])
            left_flag = False

        if tag_dict[animal_ids[1] + "_nose2tail"][fnum] and right_flag:
            write_on_frame("Nose-Tail", corners["downright"])
            right_flag = False

        if tag_dict["sidebyside"][fnum] and left_flag and conditional_flag():
            write_on_frame(
                "Side-side",
                conditional_pos(),
            )
            if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]:
                left_flag = False
            else:
                right_flag = False

        # NOTE(review): this gates on left_flag even when the text will be written
        # on the right corner — confirm whether conditional_flag() alone was intended
        if tag_dict["sidereside"][fnum] and left_flag and conditional_flag():
            write_on_frame(
                "Side-Rside",
                conditional_pos(),
            )
            if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]:
                left_flag = False
            else:
                right_flag = False

    zipped_pos = list(
        zip(
            animal_ids,
            [corners["downleft"], corners["downright"]],
            [corners["upleft"], corners["upright"]],
            [left_flag, right_flag],
        )
    )

    # Per-animal single-animal tags: write at most one of them per corner,
    # and only if the corner was not already used by a dyadic tag above
    for _id, down_pos, up_pos, flag in zipped_pos:

        if flag:

            if tag_dict[_id + undercond + "climbing"][fnum]:
                write_on_frame("climbing", down_pos)
            elif tag_dict[_id + undercond + "huddle"][fnum]:
                write_on_frame("huddling", down_pos)
            elif tag_dict[_id + undercond + "sniffing"][fnum]:
                write_on_frame("sniffing", down_pos)
            elif tag_dict[_id + undercond + "dig"][fnum]:
                write_on_frame("digging", down_pos)
            elif tag_dict[_id + undercond + "lookaround"][fnum]:
                write_on_frame("lookaround", down_pos)

        #     if (
        #         tag_dict[_id + "_following"][fnum]
        #         and not tag_dict[_id + "_climbing"][fnum]
        #     ):
        #         write_on_frame(
        #             "*f",
        #             (int(w * 0.3 / 10), int(h / 10)),
        #             conditional_col(),
        #         )

        # Define the condition controlling the colour of the speed display
        if len(animal_ids) > 1:
            colcond = frame_speeds[_id] == max(list(frame_speeds.values()))
        else:
            colcond = hparams["huddle_speed"] < frame_speeds

        write_on_frame(
            str(
                np.round(
                    (frame_speeds if len(animal_ids) == 1 else frame_speeds[_id]), 2
                )
            )
            + " mmpf",
            up_pos,
            conditional_col(cond=colcond),
        )


lucas_miranda's avatar
lucas_miranda committed
982
# noinspection PyProtectedMember
def rule_based_video(
    coordinates: Coordinates,
    tracks: List,
    videos: List,
    vid_index: int,
    tag_dict: pd.DataFrame,
    frame_limit: int = np.inf,
    recog_limit: int = 1,
    path: str = os.path.join("."),
    params: dict = None,
    debug: bool = False,
) -> bool:
    """Renders a version of the input video with all rule-based taggings in place.

    Parameters:
        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
        - tracks (list): list containing experiment IDs as strings
        - videos (list): list of videos to load, in the same order as tracks
        - vid_index (int): index in videos of the experiment to annotate
        - tag_dict (pd.DataFrame): rule-based annotations, one column per tag, one row per frame
        - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
        - recog_limit (int): number of frames to use for arena recognition (1 by default)
        - path (str): directory in which the experimental data is stored
        - params (dict): dictionary to overwrite the default values of the hyperparameters of the functions
        that the rule-based pose estimation utilizes. Values can be:
            - speed_pause (int): size of the rolling window to use when computing speeds
            - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - follow_frames (int): number of frames during which the following trait is tracked
            - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
            in order to report a detection
            - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
            - huddle_speed (int): maximum speed to report a huddle detection
        - debug (bool): if True, several debugging attributes (such as used body parts and arena) are plotted in
        the output video

    Returns:
        True

    """

    # DATA OBTENTION AND PREPARATION
    # params defaults to None (instead of a shared mutable {}) and is
    # normalised here before being merged with the default hyperparameters
    params = get_hparameters({} if params is None else params)
    animal_ids = coordinates._animal_ids
    undercond = "_" if len(animal_ids) > 1 else ""

    # Strip the DeepLabCut suffix from the experiment ID, if present
    try:
        vid_name = re.findall(r"(.*)DLC", tracks[vid_index])[0]
    except IndexError:
        vid_name = tracks[vid_index]

    arena, h, w = deepof.utils.recognize_arena(
        videos, vid_index, path, recog_limit, coordinates._arena
    )
    corners = frame_corners(h, w)

    cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
    # Keep track of the frame number, to align with the tracking data
    fnum = 0
    writer = None
    # Per-animal speed cache in multi-animal mode; a single scalar otherwise
    frame_speeds = (
        {_id: -np.inf for _id in animal_ids} if len(animal_ids) > 1 else -np.inf
    )

    # Loop over the frames in the video
    while cap.isOpened() and fnum < frame_limit:

        ret, frame = cap.read()
        # if frame is read correctly ret is True
        if not ret:  # pragma: no cover
            print("Can't receive frame (stream end?). Exiting ...")
            break

        font = cv2.FONT_HERSHEY_DUPLEX

        # Capture speeds, refreshing the cache every params["speed_pause"] frames
        try:
            if (
                list(frame_speeds.values())[0] == -np.inf
                or fnum % params["speed_pause"] == 0
            ):
                for _id in animal_ids:
                    frame_speeds[_id] = tag_dict[_id + undercond + "speed"][fnum]
        except AttributeError:
            # Single-animal mode: frame_speeds is a scalar, not a dict
            if frame_speeds == -np.inf or fnum % params["speed_pause"] == 0:
                frame_speeds = tag_dict["speed"][fnum]

        # Display all annotations in the output video
        tag_rulebased_frames(
            frame,
            font,
            frame_speeds,
            animal_ids,
            corners,
            tag_dict,
            fnum,
            undercond,
            params,
            (arena, h, w),
            debug,
            coordinates.get_coords(center=False)[vid_name],
        )

        if writer is None:
            # Define the codec and create VideoWriter object. The output is stored in an .avi file.
            # Define the FPS. Also frame size is passed.
            writer = cv2.VideoWriter()
            writer.open(
                vid_name + "_tagged.avi",
                cv2.VideoWriter_fourcc(*"MJPG"),
                params["fps"],
                (frame.shape[1], frame.shape[0]),
                True,
            )

        writer.write(frame)
        fnum += 1

    # Release video resources; the writer was previously never released,
    # which could leave the output file truncated
    if writer is not None:
        writer.release()
    cap.release()
    cv2.destroyAllWindows()

    return True
1105

1106

1107
# TODO:
#    - Is border sniffing anything you might consider interesting?