pose_utils.py 36.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for rule-based pose estimation. See documentation for details

"""

lucas_miranda's avatar
lucas_miranda committed
11
12
import os
import warnings
13
14
from itertools import combinations
from typing import Any, List, NewType
lucas_miranda's avatar
lucas_miranda committed
15

16
17
18
19
20
21
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regex as re
import seaborn as sns
22
import tensorflow as tf
lucas_miranda's avatar
lucas_miranda committed
23
24
25
from scipy import stats

import deepof.utils
26

lucas_miranda's avatar
lucas_miranda committed
27
28
29
30
# Ignore warning with no downstream effect
warnings.filterwarnings("ignore", message="All-NaN slice encountered")

# Create custom type used to annotate coordinates objects
31
Coordinates = NewType("Coordinates", Any)
32
33
34


def close_single_contact(
    pos_dframe: pd.DataFrame,
    left: str,
    right: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
) -> np.ndarray:
    """Returns a boolean array that's True if the specified body parts are closer than tol.

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left (string): First member of the potential contact
        - right (string or list of strings): Second member of the potential contact.
        When a list is given, a contact with any of its members is reported
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video

    Returns:
        - contact_array (np.array): True if the distance between the two specified points
        is less than tol, False otherwise. None if right is neither a string nor a list
    """

    def _is_close(other: str) -> np.ndarray:
        """Per-frame test of the left-to-other distance, rescaled by arena_abs / arena_rel"""
        raw_dist = np.linalg.norm(pos_dframe[left] - pos_dframe[other], axis=1)
        return (raw_dist * arena_abs) / arena_rel < tol

    if isinstance(right, str):
        return _is_close(right)

    if isinstance(right, list):
        # A frame counts as a contact as soon as one of the candidates is close enough
        return np.any([_is_close(r) for r in right], axis=0)

    # Any other type for `right` is not handled; callers get None back
    return None


def close_double_contact(
    pos_dframe: pd.DataFrame,
    left1: str,
    left2: str,
    right1: str,
    right2: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
    rev: bool = False,
) -> np.ndarray:
    """Returns a boolean array that's True if both specified pairs of body parts are closer than tol.

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left1 (string): First contact point of animal 1
        - left2 (string): Second contact point of animal 1
        - right1 (string): First contact point of animal 2
        - right2 (string): Second contact point of animal 2
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video
        - rev (bool): if False, right1 is paired with left1 and right2 with left2;
        if True, the pairing is crossed (right1 with left2, right2 with left1)

    Returns:
        - double_contact (np.array): True if both paired distances are less than tol,
        False otherwise
    """

    def _is_close(part_a: str, part_b: str) -> np.ndarray:
        """Per-frame test of the distance between two parts, rescaled by arena_abs / arena_rel"""
        raw_dist = np.linalg.norm(pos_dframe[part_a] - pos_dframe[part_b], axis=1)
        return (raw_dist * arena_abs) / arena_rel < tol

    if rev:
        double_contact = _is_close(right1, left2) & _is_close(right2, left1)
    else:
        double_contact = _is_close(right1, left1) & _is_close(right2, left2)

    return double_contact


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def rotate(origin, point, ang):
    """Auxiliary function to climb_wall and sniff_object. Rotates the x, y coordinates
    in point by ang (as passed to np.cos / np.sin) around the pivot given in origin"""

    pivot_x, pivot_y = origin
    coord_x, coord_y = point

    # Offsets of the point with respect to the pivot
    delta_x = coord_x - pivot_x
    delta_y = coord_y - pivot_y

    cos_ang = np.cos(ang)
    sin_ang = np.sin(ang)

    rotated_x = pivot_x + cos_ang * delta_x - sin_ang * delta_y
    rotated_y = pivot_y + sin_ang * delta_x + cos_ang * delta_y
    return rotated_x, rotated_y


def outside_ellipse(x, y, e_center, e_axes, e_angle, threshold=0.0):
    """Auxiliary function to climb_wall and sniff_object. Returns True if the passed x, y
    coordinates are outside the ellipse denoted by e_center, e_axes and e_angle (in degrees),
    once threshold has been added to both axes"""

    # Rotate the coordinates around the ellipse centre before evaluating the ellipse equation
    rot_x, rot_y = rotate(e_center, (x, y), np.radians(e_angle))

    axis_x = e_axes[0] + threshold
    axis_y = e_axes[1] + threshold

    x_component = (rot_x - e_center[0]) ** 2 / axis_x ** 2
    y_component = (rot_y - e_center[1]) ** 2 / axis_y ** 2

    # Values above 1 in the normalised ellipse equation lie outside the ellipse
    return x_component + y_component > 1


154
def climb_wall(
    arena_type: str,
    arena: np.array,
    pos_dict: pd.DataFrame,
    tol: float,
    nose: str,
    centered_data: bool = False,
) -> np.array:
    """Returns True if the specified mouse is climbing the wall

    Parameters:
        - arena_type (str): arena type; must be one of ['circular']
        - arena (np.array): contains arena location and shape details; arena[0] is
        used as the ellipse center, arena[1] as its axes and arena[2] as its angle
        - pos_dict (table_dict): position over time for all videos in a project
        - tol (float): minimum tolerance to report a hit; added to the arena axes
        - nose (str): indicates the name of the body part representing the nose of
        the selected animal
        - centered_data (bool): indicates whether the input data is centered; if True,
        the arena center is taken to be the origin

    Returns:
        - climbing (np.array): boolean array. True if selected animal
        is climbing the walls of the arena

    Raises:
        - NotImplementedError: if arena_type is not 'circular'
    """

    nose_pos = pos_dict[nose]

    if arena_type != "circular":
        raise NotImplementedError("Supported values for arena_type are ['circular']")

    center = np.zeros(2) if centered_data else np.array(arena[0])

    # The nose counts as climbing when it lies outside the arena ellipse enlarged by tol
    climbing = outside_ellipse(
        x=nose_pos["x"],
        y=nose_pos["y"],
        e_center=center,
        e_axes=arena[1],
        e_angle=-arena[2],
        threshold=tol,
    )

    return climbing


199
def sniff_object(
    speed_dframe: pd.DataFrame,
    arena_type: str,
    arena: np.array,
    pos_dict: pd.DataFrame,
    tol: float,
    tol_speed: float,
    nose: str,
    centered_data: bool = False,
    object: str = "arena",
    animal_id: str = "",
):
    """Returns True if the specified mouse is sniffing an object

    Parameters:
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - arena_type (str): arena type; must be one of ['circular']
        - arena (np.array): contains arena location and shape details; arena[0] is
        used as the ellipse center, arena[1] as its axes and arena[2] as its angle
        - pos_dict (table_dict): position over time for all videos in a project
        - tol (float): half-width of the band around the arena wall in which the
        nose has to be for a hit to be reported
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - nose (str): indicates the name of the body part representing the nose of
        the selected animal
        - centered_data (bool): indicates whether the input data is centered
        - object (str): indicates the object that the animal is sniffing.
        Can be one of ['arena', 'partner']; 'partner' is not implemented yet
        - animal_id (str): ID of the animal; if not empty, it is used (followed by
        an underscore) as a prefix for the "Center" speed column

    Returns:
        - sniffing (np.array): boolean array. True if selected animal
        is sniffing the selected object

    Raises:
        - NotImplementedError: if object is 'partner'
        - ValueError: if object is neither 'arena' nor 'partner'
    """

    nose_pos, nosing = pos_dict[nose], True

    if animal_id != "":
        animal_id += "_"

    if object == "arena":
        # NOTE(review): for any arena_type other than "circular" nosing stays True,
        # so the result depends on speed alone; climb_wall raises in that case -- confirm intent
        if arena_type == "circular":
            ellipse = dict(
                x=nose_pos["x"],
                y=nose_pos["y"],
                e_center=np.zeros(2) if centered_data else np.array(arena[0]),
                e_axes=arena[1],
                e_angle=-arena[2],
            )

            # The nose has to be outside the arena shrunk by tol...
            nosing_min = outside_ellipse(threshold=-tol, **ellipse)
            # ...but not outside the arena enlarged by tol
            nosing_max = outside_ellipse(threshold=tol, **ellipse)
            nosing = nosing_min & (~nosing_max)

    elif object == "partner":
        raise NotImplementedError

    else:
        raise ValueError("object should be one of [arena, partner]")

    speed = speed_dframe[animal_id + "Center"] < tol_speed
    sniffing = nosing & speed

    return sniffing


271
def huddle(
    pos_dframe: pd.DataFrame,
    speed_dframe: pd.DataFrame,
    tol_forward: float,
    tol_speed: float,
    animal_id: str = "",
) -> np.array:
    """Returns true when the mouse is huddling using simple rules.

    Parameters:
        - pos_dframe (pandas.DataFrame): position of body parts over time
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - tol_forward (float): Maximum tolerated distance between the back and
        front hips ("bhip" and "fhip" body parts) on either side of the animal
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - animal_id (str): ID of the animal; if not empty, it is used (followed by
        an underscore) as a prefix for all body part names

    Returns:
        hudd (np.array): True if the animal is huddling, False otherwise
    """

    if animal_id != "":
        animal_id += "_"

    def _hips_close(side: str) -> np.ndarray:
        """Per-frame test of the back-to-front hip distance on the given side"""
        back_hip = pos_dframe[animal_id + side + "_bhip"]
        front_hip = pos_dframe[animal_id + side + "_fhip"]
        return np.linalg.norm(back_hip - front_hip, axis=1) < tol_forward

    # One contracted side is enough
    forward = _hips_close("Left") | _hips_close("Right")

    speed = speed_dframe[animal_id + "Center"] < tol_speed
    hudd = forward & speed

    return hudd


316
def dig(
    speed_dframe: pd.DataFrame,
    likelihood_dframe: pd.DataFrame,
    tol_speed: float,
    tol_likelihood: float,
    animal_id: str = "",
):
    """Returns true when the mouse is digging using simple rules.

    Parameters:
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - likelihood_dframe (pandas.DataFrame): likelihood of body part tracker over time,
        as directly obtained from DeepLabCut
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - tol_likelihood (float): Maximum tolerated likelihood for the nose (if the animal
        is digging, the nose is momentarily occluded).
        - animal_id (str): ID of the animal; if not empty, it is used (followed by
        an underscore) as a prefix for the "Center" and "Nose" columns

    Returns:
        dig (np.array): True if the animal is digging, False otherwise
    """

    prefix = animal_id + "_" if animal_id != "" else animal_id

    slow_center = speed_dframe[prefix + "Center"] < tol_speed
    occluded_nose = likelihood_dframe[prefix + "Nose"] < tol_likelihood

    # Digging requires both a slow center and a poorly tracked nose
    digging = slow_center & occluded_nose

    return digging
345
346


347
def look_around(
    speed_dframe: pd.DataFrame,
    likelihood_dframe: pd.DataFrame,
    tol_speed: float,
    tol_likelihood: float,
    animal_id: str = "",
):
    """Returns true when the mouse is standing still and looking around using simple rules.

    Parameters:
        - speed_dframe (pandas.DataFrame): speed of body parts over time
        - likelihood_dframe (pandas.DataFrame): likelihood of body part tracker over time,
        as directly obtained from DeepLabCut
        - tol_speed (float): Maximum tolerated speed for the center of the mouse
        - tol_likelihood (float): Minimum required likelihood for the nose (unlike in
        dig, the nose has to be properly tracked here).
        - animal_id (str): ID of the animal; if not empty, it is used (followed by
        an underscore) as a prefix for the "Center" and "Nose" columns

    Returns:
        lookaround (np.array): True if the animal is standing still and looking around, False otherwise
    """

    prefix = animal_id + "_" if animal_id != "" else animal_id

    center_speed = speed_dframe[prefix + "Center"]

    slow_center = center_speed < tol_speed
    # The nose has to move faster than the center of the animal
    nose_faster = center_speed < speed_dframe[prefix + "Nose"]
    visible_nose = likelihood_dframe[prefix + "Nose"] > tol_likelihood

    lookaround = slow_center & visible_nose & nose_faster

    return lookaround
378
379


380
def following_path(
    distance_dframe: pd.DataFrame,
    position_dframe: pd.DataFrame,
    follower: str,
    followed: str,
    frames: int = 20,
    tol: float = 0,
) -> np.array:
    """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
    followed has walked over the last specified number of frames

        Parameters:
            - distance_dframe (pandas.DataFrame): distances between bodyparts; generated by the preprocess module.
            Columns are looked up with the alphabetically sorted pair of body part names
            - position_dframe (pandas.DataFrame): position of bodyparts; generated by the preprocess module
            - follower (str) identifier for the animal who's following
            - followed (str) identifier for the animal who's followed
            - frames (int) frames in which to track whether the process consistently occurs,
            - tol (float) Maximum distance for which True is returned

        Returns:
            - follow (np.array): boolean sequence, True if conditions are fulfilled, False otherwise"""

    follower_nose = position_dframe[follower + "_Nose"]
    followed_tail = position_dframe[followed + "_Tail_base"]

    # Check that follower is close enough to the path that followed has passed through in the
    # last frames: one column per lag, holding the distance to the tail base `lag` frames ago
    dist_df = pd.DataFrame(
        {
            lag: np.linalg.norm(follower_nose - followed_tail.shift(lag), axis=1)
            for lag in range(frames)
        }
    )

    def _pair_distance(part_a: str, part_b: str) -> pd.Series:
        """Fetches the distance column of a pair of body parts (keys are sorted tuples)"""
        return distance_dframe[tuple(sorted([part_a, part_b]))]

    # Check that the animals are oriented follower's nose -> followed's tail
    nose2tail = _pair_distance(follower + "_Nose", followed + "_Tail_base")
    right_orient1 = nose2tail < _pair_distance(
        follower + "_Tail_base", followed + "_Tail_base"
    )
    right_orient2 = nose2tail < _pair_distance(follower + "_Nose", followed + "_Nose")

    follow = np.all(
        np.array([(dist_df.min(axis=1) < tol), right_orient1, right_orient2]),
        axis=0,
    )

    return follow


def single_behaviour_analysis(
    behaviour_name: str,
    treatment_dict: dict,
    behavioural_dict: dict,
    plot: int = 0,
    stat_tests: bool = True,
    save: str = None,
    ylim: float = None,
) -> list:
    """Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
    with the actual tags, outputs a box plot and a series of significance tests amongst the groups

     Parameters:
         - behaviour_name (str): name of the behavioural trait to analyse
         - treatment_dict (dict): dictionary containing experimental conditions as keys and
         the video names belonging to each condition as values
         - behavioural_dict (dict): tagged dictionary containing video names as keys and annotations as values
         - plot (int): Silent if 0; otherwise, indicates the dpi of the figure to plot
         - stat_tests (bool): performs two-sided Mann-Whitney U tests among all pairs of groups if True
         - save (str): Saves the produced figure to the specified file
         - ylim (float): y-limit for the boxplot. Ignored if plot == 0

     Returns:
         - return_list (list): contains, in this order:
             - beh_dict (dict): experimental conditions as keys and, as values, the list with
             the proportion of frames showing the behaviour in each of their videos
             - fig: the produced figure (only if plot > 0)
             - stat_dict (dict): condition pairs as keys and stat results as values
             (only if stat_tests is True)"""

    beh_dict = {condition: [] for condition in treatment_dict}

    for condition, videos in treatment_dict.items():
        for ind in videos:
            tags = behavioural_dict[ind][behaviour_name]
            # append, not +=: adding a numpy scalar to a list with += hands the operation
            # over to numpy, which replaces the list with an (empty) array
            beh_dict[condition].append(np.sum(tags) / len(tags))

    return_list = [beh_dict]

    if plot > 0:

        fig, ax = plt.subplots(dpi=plot)

        # Long-form vectors: one (condition, proportion) pair per video
        sns.boxplot(
            x=[condition for condition, values in beh_dict.items() for _ in values],
            y=[value for values in beh_dict.values() for value in values],
            orient="vertical",
            ax=ax,
        )

        ax.set_title("{} across groups".format(behaviour_name))
        ax.set_ylabel("Proportion of frames")

        if ylim is not None:
            ax.set_ylim(ylim)

        if save is not None:  # pragma: no cover
            plt.savefig(save)

        return_list.append(fig)

    if stat_tests:
        stat_dict = {}
        for pair in combinations(treatment_dict.keys(), 2):
            first, second = beh_dict[pair[0]], beh_dict[pair[1]]

            # Solves issue with automatically generated examples:
            # the test cannot run on identical or constant samples
            if first == second or np.var(first) == 0 or np.var(second) == 0:
                stat_dict[pair] = "Identical sources. Couldn't run"
            else:
                stat_dict[pair] = stats.mannwhitneyu(
                    first, second, alternative="two-sided"
                )
        return_list.append(stat_dict)

    return return_list


def max_behaviour(
    behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.ndarray:
    """Returns the most frequent behaviour in a window of window_size frames

    Parameters:
            - behaviour_dframe (pd.DataFrame): boolean matrix containing occurrence
            of tagged behaviours per frame in the video. Columns whose name contains
            "speed" (case-insensitive) are ignored
            - window_size (int): size of the window to use when computing
            the maximum behaviour per time slot
            - stepped (bool): sliding windows don't overlap if True. False by default

    Returns:
        - max_array (np.array): string array with the most common behaviour per instance
        of the sliding window. The first row of the windowed table is not included"""

    speed_cols = [col for col in behaviour_dframe.columns if "speed" in col.lower()]
    tags = behaviour_dframe.drop(speed_cols, axis=1).astype("float")

    # Number of frames showing each behaviour within a centered rolling window
    win_array = tags.rolling(window_size, center=True).sum()
    if stepped:
        # Keep one window every window_size frames
        win_array = win_array[::window_size]

    max_array = win_array[1:].idxmax(axis=1)

    return np.array(max_array)


543
# noinspection PyDefaultArgument
lucas_miranda's avatar
lucas_miranda committed
544
545
546
def get_hparameters(hparams: dict = None) -> dict:
    """Returns the hyperparameters used by the rule-based annotation functions,
    with the defaults overwritten by any value passed in hparams

    Parameters:
        - hparams (dict): dictionary containing hyperparameters to overwrite.
        None (default) or an empty dictionary leave all defaults untouched

    Returns:
        - defaults (dict): dictionary with overwritten parameters. Those not
        specified in the input retain their default values"""

    defaults = {
        "speed_pause": 5,
        "climb_tol": 10,
        "close_contact_tol": 35,
        "side_contact_tol": 80,
        "follow_frames": 10,
        "follow_tol": 5,
        "huddle_forward": 15,
        "huddle_speed": 1,
        "nose_likelihood": 0.85,
        "fps": 24,
    }

    # Keys that are not among the defaults are added as they are
    if hparams is not None:
        defaults.update(hparams)

    return defaults


573
574
575
576
# noinspection PyDefaultArgument
def frame_corners(w, h, corners: dict = None):
    """Returns a dictionary with the corner positions of the video frame

    Parameters:
        - w (int): width of the frame in pixels
        - h (int): height of the frame in pixels
        - corners (dict): dictionary containing corners to overwrite.
        None (default) or an empty dictionary leave all defaults untouched

    Returns:
        - defaults (dict): dictionary with overwritten parameters. Those not
        specified in the input retain their default values"""

    # Each corner is an (x, y) tuple of integer pixel positions derived from the frame size
    defaults = {
        "downleft": (int(w * 0.3 / 10), int(h / 1.05)),
        "downright": (int(w * 6.5 / 10), int(h / 1.05)),
        "upleft": (int(w * 0.3 / 10), int(h / 20)),
        "upright": (int(w * 6.3 / 10), int(h / 20)),
    }

    if corners is not None:
        defaults.update(corners)

    return defaults


599
# noinspection PyDefaultArgument,PyProtectedMember
600
def rule_based_tagging(
    coordinates: Coordinates,
    coords: Any,
    dists: Any,
    speeds: Any,
    video: str,
    params: dict = None,
) -> pd.DataFrame:
    """Outputs a dataframe with the registered motives per frame

    Parameters:
        - coordinates (deepof.data.coordinates): coordinates object containing the project information
        - coords (deepof.data.table_dict): table_dict with already processed coordinates
        - dists (deepof.data.table_dict): table_dict with already processed distances
        - speeds (deepof.data.table_dict): table_dict with already processed speeds
        - video (str): video of the experiment to annotate; it has to be one of the
        videos stored in the coordinates object
        - params (dict): dictionary to overwrite the default values of the parameters of the functions
        that the rule-based pose estimation utilizes. See documentation for details.

    Returns:
        - tag_df (pandas.DataFrame): table with traits as columns and frames as rows. Each
        value indicates trait detection at a given time, except for the speed columns,
        which hold the median speed across body parts. Missing values are filled with 0"""

    # Extract useful information from coordinates object
    tracks = list(coordinates._tables.keys())
    vid_index = coordinates._videos.index(video)

    arena_params = coordinates._arena_params[vid_index]
    arena_type = coordinates._arena

    params = get_hparameters(params if params is not None else {})
    animal_ids = coordinates._animal_ids
    # Body part names are prefixed with "<animal_id>_" only in multi-animal projects
    undercond = "_" if len(animal_ids) > 1 else ""

    try:
        vid_name = re.findall("(.*)DLC", tracks[vid_index])[0]
    except IndexError:
        # The track name contains no "DLC" suffix; use it as it is
        vid_name = tracks[vid_index]

    coords = coords[vid_name]
    dists = dists[vid_name]
    speeds = speeds[vid_name]
    likelihoods = coordinates.get_quality()[vid_name]
    arena_abs = coordinates.get_arenas[1][0]

    # Dictionary with motives per frame
    tag_dict = {}

    # Bulk body parts
    main_body = [
        "Left_ear",
        "Right_ear",
        "Spine_1",
        "Center",
        "Spine_2",
        "Left_fhip",
        "Right_fhip",
        "Left_bhip",
        "Right_bhip",
    ]

    def onebyone_contact(bparts: List):
        """Returns a smooth boolean array with 1to1 contacts between two mice"""

        # A member of bparts can be a suffix string (e.g. "_Nose") or a list of body
        # part names; adding a list to a string raises the TypeError handled below
        try:
            left = animal_ids[0] + bparts[0]
        except TypeError:
            left = [animal_ids[0] + "_" + suffix for suffix in bparts[0]]

        try:
            right = animal_ids[1] + bparts[-1]
        except TypeError:
            right = [animal_ids[1] + "_" + suffix for suffix in bparts[-1]]

        # close_single_contact takes the list (if any) as its second member
        if isinstance(left, list):
            left, right = right, left

        return deepof.utils.smooth_boolean_array(
            close_single_contact(
                coords,
                left,
                right,
                params["close_contact_tol"],
                arena_abs,
                arena_params[1][1],
            )
        )

    def twobytwo_contact(rev):
        """Returns a smooth boolean array with side by side contacts between two mice"""

        return deepof.utils.smooth_boolean_array(
            close_double_contact(
                coords,
                animal_ids[0] + "_Nose",
                animal_ids[0] + "_Tail_base",
                animal_ids[1] + "_Nose",
                animal_ids[1] + "_Tail_base",
                params["side_contact_tol"],
                rev=rev,
                arena_abs=arena_abs,
                arena_rel=arena_params[1][1],
            )
        )

    def overall_speed(ovr_speeds, _id, ucond):
        """Returns the per-frame median speed across body parts of the given animal"""

        bparts = [
            "Center",
            "Spine_1",
            "Spine_2",
            "Nose",
            "Left_ear",
            "Right_ear",
            "Left_fhip",
            "Right_fhip",
            "Left_bhip",
            "Right_bhip",
            "Tail_base",
        ]
        array = ovr_speeds[[_id + ucond + bpart for bpart in bparts]]
        # The first frame is left out of the median and reported as NaN
        avg_speed = np.nanmedian(array[1:], axis=1)
        return np.insert(avg_speed, 0, np.nan, axis=0)

    if len(animal_ids) == 2:
        # Define behaviours that can be computed on the fly from the distance matrix
        tag_dict["nose2nose"] = onebyone_contact(bparts=["_Nose"])

        tag_dict["sidebyside"] = twobytwo_contact(rev=False)

        tag_dict["sidereside"] = twobytwo_contact(rev=True)

        tag_dict[animal_ids[0] + "_nose2tail"] = onebyone_contact(
            bparts=["_Nose", "_Tail_base"]
        )
        tag_dict[animal_ids[1] + "_nose2tail"] = onebyone_contact(
            bparts=["_Tail_base", "_Nose"]
        )
        tag_dict[animal_ids[0] + "_nose2body"] = onebyone_contact(
            bparts=[
                "_Nose",
                main_body,
            ]
        )
        tag_dict[animal_ids[1] + "_nose2body"] = onebyone_contact(
            bparts=[
                main_body,
                "_Nose",
            ]
        )

        for _id in animal_ids:
            tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
                following_path(
                    dists,
                    coords,
                    follower=_id,
                    followed=[i for i in animal_ids if i != _id][0],
                    frames=params["follow_frames"],
                    tol=params["follow_tol"],
                )
            )

    for _id in animal_ids:
        tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
            climb_wall(
                arena_type,
                arena_params,
                coords,
                params["climb_tol"],
                _id + undercond + "Nose",
            )
        )
        tag_dict[_id + undercond + "sniffing"] = deepof.utils.smooth_boolean_array(
            sniff_object(
                speeds,
                arena_type,
                arena_params,
                coords,
                params["climb_tol"],
                params["huddle_speed"],
                _id + undercond + "Nose",
                object="arena",
                animal_id=_id,
            )
        )
        tag_dict[_id + undercond + "speed"] = overall_speed(speeds, _id, undercond)
        tag_dict[_id + undercond + "huddle"] = deepof.utils.smooth_boolean_array(
            huddle(
                coords,
                speeds,
                params["huddle_forward"],
                params["huddle_speed"],
                animal_id=_id,
            )
        )
        tag_dict[_id + undercond + "dig"] = deepof.utils.smooth_boolean_array(
            dig(
                speeds,
                likelihoods,
                params["huddle_speed"],
                params["nose_likelihood"],
                animal_id=_id,
            )
        )
        tag_dict[_id + undercond + "lookaround"] = deepof.utils.smooth_boolean_array(
            look_around(
                speeds,
                likelihoods,
                params["huddle_speed"],
                params["nose_likelihood"],
                animal_id=_id,
            )
        )

    tag_df = pd.DataFrame(tag_dict).fillna(0)

    return tag_df


lucas_miranda's avatar
lucas_miranda committed
823
def tag_rulebased_frames(
    frame,
    font,
    frame_speeds,
    animal_ids,
    corners,
    tag_dict,
    fnum,
    undercond,
    hparams,
    arena,
    debug,
    coords,
):
    """Helper function for rule_based_video. Annotates a given frame with on-screen information
    about the recognised patterns.

    Parameters:
        - frame: image handed to the cv2 drawing calls; all text and shapes are drawn on it
        - font: font identifier passed to cv2.putText
        - frame_speeds: dictionary mapping each animal ID to its current speed when there is
        more than one animal; a single speed value otherwise
        - animal_ids: sequence with the IDs of the tracked animals. Only the first two
        get a screen corner assigned
        - corners (dict): text anchor positions under the keys "downleft", "downright",
        "upleft" and "upright"
        - tag_dict: tag table, indexed first by tag name and then by frame number
        - fnum (int): number of the frame being annotated
        - undercond (str): separator placed between the animal ID and the tag name
        when building the single-animal tag keys
        - hparams (dict): hyperparameters; only "huddle_speed" is read here
        - arena (tuple): three-element tuple with the arena description followed by the
        two frame dimensions
        - debug (bool): if True, the arena, the tracked body parts and the frame number
        are drawn as well
        - coords: body part coordinates of the current video; only used when debug is True

    Returns:
        None. The annotations are drawn on frame.

    """

    # NOTE(review): rule_based_video builds this tuple as (arena, h, w) while it is
    # unpacked here as (arena, w, h). Only the debug frame counter uses w and h, and the
    # order is kept exactly as it was -- confirm which dimension is which before renaming.
    arena, w, h = arena

    multi_animal = len(animal_ids) > 1

    # With two animals, dyadic tags are written on the side of the animal that is
    # currently faster: left for the first ID, right for the second (ties go right).
    # The comparison is only meaningful (and only possible) in the multi-animal case.
    first_is_faster = (
        multi_animal and frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
    )
    faster_corner = "downleft" if first_is_faster else "downright"

    def write_on_frame(text, pos, col=(255, 255, 255)):
        """Partial closure over cv2.putText to avoid code repetition"""
        return cv2.putText(frame, text, pos, font, 0.75, col, 2)

    def conditional_col(cond=None):
        """Returns a colour depending on a condition"""
        if cond is None:
            cond = first_is_faster
        return (150, 255, 150) if cond else (150, 150, 255)

    # Keep track of space usage in the output video. A lower corner is
    # marked as taken (False) as soon as a tag has been written on it
    corner_free = {"downleft": True, "downright": True}

    if debug:
        # Print arena for debugging
        cv2.ellipse(frame, arena[0], arena[1], arena[2], 0, 360, (0, 255, 0), 3)
        # Print body parts for debugging
        for bpart in coords.columns.levels[0]:
            if not np.isnan(coords[bpart]["x"][fnum]):
                cv2.circle(
                    frame,
                    (int(coords[bpart]["x"][fnum]), int(coords[bpart]["y"][fnum])),
                    radius=3,
                    color=(
                        (255, 0, 0) if bpart.startswith(animal_ids[0]) else (0, 0, 255)
                    ),
                    thickness=-1,
                )
        # Print frame number
        write_on_frame("Frame " + str(fnum), (int(w * 0.3 / 10), int(h / 1.15)))

    if multi_animal:

        if tag_dict["nose2nose"][fnum]:
            write_on_frame("Nose-Nose", corners[faster_corner])
            corner_free[faster_corner] = False

        # Directed contacts: the first animal reports on the left corner and
        # the second one on the right, each only if its corner is still free
        for suffix, label in (("_nose2body", "nose2body"), ("_nose2tail", "Nose-Tail")):
            for _id, corner in zip(animal_ids[:2], ("downleft", "downright")):
                if tag_dict[_id + suffix][fnum] and corner_free[corner]:
                    write_on_frame(label, corners[corner])
                    corner_free[corner] = False

        # Side contacts go to the faster animal's corner. They require the left corner
        # to be free in addition to the target corner, exactly as before.
        # NOTE(review): confirm the extra left-corner requirement is intended.
        for tag, label in (("sidebyside", "Side-side"), ("sidereside", "Side-Rside")):
            if (
                tag_dict[tag][fnum]
                and corner_free["downleft"]
                and corner_free[faster_corner]
            ):
                write_on_frame(label, corners[faster_corner])
                corner_free[faster_corner] = False

    # Single-animal behaviours in order of priority: only the first active one is shown
    single_tags = (
        ("climbing", "climbing"),
        ("huddle", "huddling"),
        ("sniffing", "sniffing"),
        ("dig", "digging"),
        ("lookaround", "lookaround"),
    )

    for _id, down_corner, up_corner in zip(
        animal_ids, ("downleft", "downright"), ("upleft", "upright")
    ):

        if corner_free[down_corner]:
            for tag, label in single_tags:
                if tag_dict[_id + undercond + tag][fnum]:
                    write_on_frame(label, corners[down_corner])
                    break

        # NOTE: the on-screen marker for the "following" tag is currently disabled

        # Define the condition controlling the colour of the speed display
        if multi_animal:
            speed = frame_speeds[_id]
            colcond = speed == max(frame_speeds.values())
        else:
            speed = frame_speeds
            colcond = hparams["huddle_speed"] < frame_speeds

        write_on_frame(
            str(np.round(speed, 2)) + " mmpf",
            corners[up_corner],
            conditional_col(cond=colcond),
        )


lucas_miranda's avatar
lucas_miranda committed
989
# noinspection PyProtectedMember
def rule_based_video(
    coordinates: Coordinates,
    tracks: List,
    videos: List,
    vid_index: int,
    tag_dict: pd.DataFrame,
    frame_limit: float = np.inf,
    recog_limit: int = 100,
    path: str = os.path.join("."),
    params: dict = None,
    debug: bool = False,
) -> bool:
    """Renders a version of the input video with all rule-based taggings in place.

    The annotated video is written to "<video name>_tagged.avi" (MJPG codec) in the
    current working directory.

    Parameters:
        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
        - tracks (list): list containing experiment IDs as strings
        - videos (list): list of videos to load, in the same order as tracks
        - vid_index (int): index in videos of the experiment to annotate
        - tag_dict (pandas.DataFrame): rule-based tags of the video, indexed by tag name and frame number
        - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
        - recog_limit (int): number of frames to use for arena recognition (100 by default)
        - path (str): directory in which the experimental data is stored
        - params (dict): dictionary to overwrite the default values of the hyperparameters of the functions
        that the rule-based pose estimation utilizes (None keeps all defaults). Values can be:
            - speed_pause (int): size of the rolling window to use when computing speeds
            - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - follow_frames (int): number of frames during which the following trait is tracked
            - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
            in order to report a detection
            - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
            - huddle_speed (int): maximum speed to report a huddle detection
            - fps: frame rate used for the output video
        - debug (bool): if True, several debugging attributes (such as used body parts and arena) are plotted in
        the output video

    Returns:
        True

    """

    # DATA OBTENTION AND PREPARATION
    params = get_hparameters({} if params is None else params)
    animal_ids = coordinates._animal_ids
    multi_animal = len(animal_ids) > 1
    undercond = "_" if multi_animal else ""

    # Strip the DLC suffix from the track name when present
    track = tracks[vid_index]
    matches = re.findall("(.*)DLC", track)
    vid_name = matches[0] if matches else track

    arena, h, w = deepof.utils.recognize_arena(
        videos,
        vid_index,
        path,
        recog_limit,
        coordinates._arena,
        detection_mode=coordinates._arena_detection,
        cnn_model=coordinates._ellipse_detection_model,
    )
    corners = frame_corners(h, w)

    # Neither the coordinates of the video nor the font change between frames,
    # so they are obtained once instead of on every iteration
    vid_coords = coordinates.get_coords(center=False)[vid_name]
    font = cv2.FONT_HERSHEY_DUPLEX

    # -inf marks speeds that have not been read yet. One value per animal is
    # kept when there are several animals, and a single value otherwise
    frame_speeds = (
        {_id: -np.inf for _id in animal_ids} if multi_animal else -np.inf
    )

    cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
    # Keep track of the frame number, to align with the tracking data
    fnum = 0
    writer = None

    try:
        # Loop over the frames in the video
        while cap.isOpened() and fnum < frame_limit:

            ret, frame = cap.read()
            # if frame is read correctly ret is True
            if not ret:  # pragma: no cover
                print("Can't receive frame (stream end?). Exiting ...")
                break

            # Capture speeds: displayed values are refreshed on the first
            # frame and then once every speed_pause frames
            if multi_animal:
                if (
                    next(iter(frame_speeds.values())) == -np.inf
                    or fnum % params["speed_pause"] == 0
                ):
                    for _id in animal_ids:
                        frame_speeds[_id] = tag_dict[_id + undercond + "speed"][fnum]
            elif frame_speeds == -np.inf or fnum % params["speed_pause"] == 0:
                frame_speeds = tag_dict["speed"][fnum]

            # Display all annotations in the output video
            tag_rulebased_frames(
                frame,
                font,
                frame_speeds,
                animal_ids,
                corners,
                tag_dict,
                fnum,
                undercond,
                params,
                (arena, h, w),
                debug,
                vid_coords,
            )

            if writer is None:
                # Define the codec and create the VideoWriter object lazily, so that
                # the frame size can be taken from the first frame that was read
                writer = cv2.VideoWriter()
                writer.open(
                    vid_name + "_tagged.avi",
                    cv2.VideoWriter_fourcc(*"MJPG"),
                    params["fps"],
                    (frame.shape[1], frame.shape[0]),
                    True,
                )

            writer.write(frame)
            fnum += 1

    finally:
        # Release the video handles even if annotating a frame failed. Releasing the
        # writer explicitly makes sure the output file is finalised
        cap.release()
        if writer is not None:
            writer.release()
        cv2.destroyAllWindows()

    return True
1118

1119

1120
# TODO:
#    - Is border sniffing anything you might consider interesting?