# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Functions and general utilities for rule-based pose estimation. See documentation for details

"""

import cv2
import deepof.utils
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import regex as re
import seaborn as sns
from itertools import combinations
from scipy import stats
from typing import Any, List, NewType

Coordinates = NewType("Coordinates", Any)


def close_single_contact(
    pos_dframe: pd.DataFrame,
    left: str,
    right: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
) -> np.array:
    """Returns a boolean array that's True if the specified body parts are closer than tol.

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left (string): First member of the potential contact
        - right (string): Second member of the potential contact
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video

    Returns:
        - contact_array (np.array): True if the distance between the two specified points
        is less than tol, False otherwise"""

    # Frame-wise Euclidean distance between the two body parts, in pixels
    pixel_dist = np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1)

    # Rescale pixels to mm with the arena diameter ratio before thresholding
    return (pixel_dist * arena_abs) / arena_rel < tol


def close_double_contact(
    pos_dframe: pd.DataFrame,
    left1: str,
    left2: str,
    right1: str,
    right2: str,
    tol: float,
    arena_abs: int,
    arena_rel: int,
    rev: bool = False,
) -> np.array:
    """Returns a boolean array that's True if the specified body parts are closer than tol.

    Parameters:
        - pos_dframe (pandas.DataFrame): DLC output as pandas.DataFrame; only applicable
        to two-animal experiments.
        - left1 (string): First contact point of animal 1
        - left2 (string): Second contact point of animal 1
        - right1 (string): First contact point of animal 2
        - right2 (string): Second contact point of animal 2
        - tol (float): maximum distance for which a contact is reported
        - arena_abs (int): length in mm of the diameter of the real arena
        - arena_rel (int): length in pixels of the diameter of the arena in the video
        - rev (bool): reverses the default behaviour (nose2tail contact for both mice)

    Returns:
        - double_contact (np.array): True if the distance between the two specified points
        is less than tol, False otherwise"""

    def _within_tol(bpart_a, bpart_b):
        """Frame-wise mm-scaled distance between two body parts, thresholded at tol"""
        dist = np.linalg.norm(pos_dframe[bpart_a] - pos_dframe[bpart_b], axis=1)
        return (dist * arena_abs) / arena_rel < tol

    if rev:
        # Reversed: each animal's first point against the other's second point
        double_contact = _within_tol(right1, left2) & _within_tol(right2, left1)
    else:
        # Default: homologous points of both animals paired up
        double_contact = _within_tol(right1, left1) & _within_tol(right2, left2)

    return double_contact


def climb_wall(
    arena_type: str,
    arena: np.array,
    pos_dict: pd.DataFrame,
    tol: float,
    nose: str,
    centered_data: bool = False,
) -> np.array:
    """Returns True if the specified mouse is climbing the wall

    Parameters:
        - arena_type (str): arena type; must be one of ['circular']
        - arena (np.array): contains arena location and shape details
        - pos_dict (table_dict): position over time for all videos in a project
        - tol (float): minimum tolerance to report a hit
        - nose (str): indicates the name of the body part representing the nose of
        the selected animal
        - centered_data (bool): indicates whether the input data is centered

    Returns:
        - climbing (np.array): boolean array. True if selected animal
        is climbing the walls of the arena"""

    nose_pos = pos_dict[nose]

    if arena_type != "circular":
        raise NotImplementedError("Supported values for arena_type are ['circular']")

    def _rotated(origin, px, py, theta):
        """Rotates (px, py) counterclockwise by theta radians around origin"""
        ox, oy = origin
        dx, dy = px - ox, py - oy
        return (
            ox + np.cos(theta) * dx - np.sin(theta) * dy,
            oy + np.sin(theta) * dx + np.cos(theta) * dy,
        )

    ellipse_center = np.zeros(2) if centered_data else np.array(arena[0])
    semi_axes, ellipse_angle = arena[1], arena[2]

    # Align the nose coordinates with the (possibly tilted) arena ellipse;
    # note the sign flip on the fitted angle
    rx, ry = _rotated(
        ellipse_center, nose_pos["x"], nose_pos["y"], np.radians(-ellipse_angle)
    )

    # Standard ellipse equation with the semi-axes inflated by tol:
    # values above 1 fall outside the (grown) arena boundary
    term_x = (rx - ellipse_center[0]) ** 2 / (semi_axes[0] + tol) ** 2
    term_y = (ry - ellipse_center[1]) ** 2 / (semi_axes[1] + tol) ** 2
    return term_x + term_y > 1


def huddle(
    pos_dframe: pd.DataFrame,
    speed_dframe: pd.DataFrame,
    tol_forward: float,
    tol_spine: float,
    tol_speed: float,
    animal_id: str = "",
) -> np.array:
    """Returns true when the mouse is huddling using simple rules. (!!!) Designed to
    work with deepof's default DLC mice models; not guaranteed to work otherwise.

        Parameters:
            - pos_dframe (pandas.DataFrame): position of body parts over time
            - speed_dframe (pandas.DataFrame): speed of body parts over time
            - tol_forward (float): Maximum tolerated distance between ears and
            forward limbs
            - tol_spine (float): Maximum tolerated average distance between spine
            body parts
            - tol_speed (float): Maximum tolerated speed for the center of the mouse
            - animal_id (str): animal ID prefix; empty for single-animal experiments

        Returns:
            hudd (np.array): True if the animal is huddling, False otherwise
    """

    if animal_id != "":
        animal_id += "_"

    def _bpart_dist(bpart_a, bpart_b):
        """Frame-wise distance between two labelled body parts of this animal"""
        return np.linalg.norm(
            pos_dframe[animal_id + bpart_a] - pos_dframe[animal_id + bpart_b], axis=1
        )

    # Both ears must be close to the forward hips (tucked head posture)
    forward = (_bpart_dist("Left_ear", "Left_fhip") < tol_forward) & (
        _bpart_dist("Right_ear", "Right_fhip") < tol_forward
    )

    # NOTE(review): preserved from the original — only the first two of the three
    # consecutive spine segments enter the average, and np.mean collapses over
    # frames as well as segments, yielding a single scalar criterion for the
    # whole recording rather than a per-frame one. Confirm this is intended.
    spine_parts = [
        animal_id + "Spine_1",
        animal_id + "Center",
        animal_id + "Spine_2",
        animal_id + "Tail_base",
    ]
    segment_dists = [
        np.linalg.norm(
            pos_dframe[spine_parts[seg]] - pos_dframe[spine_parts[seg + 1]], axis=1
        )
        for seg in range(2)
    ]
    spine_compact = np.mean(segment_dists) < tol_spine

    # The animal's centre must also be near-stationary
    slow = speed_dframe[animal_id + "Center"] < tol_speed

    return forward & spine_compact & slow


def following_path(
    distance_dframe: pd.DataFrame,
    position_dframe: pd.DataFrame,
    follower: str,
    followed: str,
    frames: int = 20,
    tol: float = 0,
) -> np.array:
    """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
    followed has walked over the last specified number of frames

        Parameters:
            - distance_dframe (pandas.DataFrame): distances between bodyparts; generated by the preprocess module
            - position_dframe (pandas.DataFrame): position of bodyparts; generated by the preprocess module
            - follower (str) identifier for the animal who's following
            - followed (str) identifier for the animal who's followed
            - frames (int) frames in which to track whether the process consistently occurs,
            - tol (float) Maximum distance for which True is returned

        Returns:
            - follow (np.array): boolean sequence, True if conditions are fulfilled, False otherwise"""

    nose_col = follower + "_Nose"
    tail_col = followed + "_Tail_base"

    def _dist_between(bpart_a, bpart_b):
        """Looks up a precomputed distance column (keyed by sorted body part tuples)"""
        return distance_dframe[tuple(sorted([bpart_a, bpart_b]))]

    # Distance from the follower's nose to where the followed's tail base
    # was 0, 1, ..., frames-1 frames ago
    lagged_dists = pd.DataFrame(
        {
            lag: np.linalg.norm(
                position_dframe[nose_col] - position_dframe[tail_col].shift(lag),
                axis=1,
            )
            for lag in range(frames)
        }
    )

    # Orientation check: follower's nose closer to followed's tail base than
    # the follower's own tail base is...
    right_orient1 = _dist_between(nose_col, tail_col) < _dist_between(
        follower + "_Tail_base", tail_col
    )
    # ...and closer than it is to the followed's nose
    right_orient2 = _dist_between(nose_col, tail_col) < _dist_between(
        nose_col, followed + "_Nose"
    )

    # The nose must have come within tol of the recent path at some lag
    close_to_path = lagged_dists.min(axis=1) < tol

    return np.all(np.array([close_to_path, right_orient1, right_orient2]), axis=0)


def single_behaviour_analysis(
    behaviour_name: str,
    treatment_dict: dict,
    behavioural_dict: dict,
    plot: int = 0,
    stat_tests: bool = True,
    save: str = None,
    ylim: float = None,
) -> list:
    """Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
    with the actual tags, outputs a box plot and a series of significance tests amongst the groups

     Parameters:
         - behaviour_name (str): name of the behavioural trait to analize
         - treatment_dict (dict): dictionary containing experimental conditions as keys and
         lists of video names as values
         - behavioural_dict (dict): tagged dictionary containing video names as keys and annotations as values
         - plot (int): Silent if 0; otherwise, indicates the dpi of the figure to plot
         - stat_tests (bool): performs FDR corrected Mann-U non-parametric tests among all groups if True
         - save (str): Saves the produced figure to the specified file
         - ylim (float): y-limit for the boxplot. Ignored if plot == False

     Returns:
         - return_list (list): first element is a dict mapping each experimental condition
         to the per-video proportion of frames in which the behaviour occurs; the figure
         (if plot > 0) and a dict mapping condition pairs to Mann-Whitney U results
         (if stat_tests) are appended in that order"""

    beh_dict = {condition: [] for condition in treatment_dict.keys()}

    for condition in beh_dict.keys():
        for ind in treatment_dict[condition]:
            # Proportion of frames in which the behaviour was tagged for this video.
            # BUG FIX: the original used `+=` on the list, which raises TypeError for
            # a scalar right-hand side (list.__iadd__ needs an iterable); append the
            # per-video proportion instead
            tags = behavioural_dict[ind][behaviour_name]
            beh_dict[condition].append(np.sum(tags) / len(tags))

    return_list = [beh_dict]

    if plot > 0:

        fig, ax = plt.subplots(dpi=plot)

        sns.boxplot(
            x=list(beh_dict.keys()),
            y=list(beh_dict.values()),
            orient="vertical",
            ax=ax,
        )

        ax.set_title("{} across groups".format(behaviour_name))
        ax.set_ylabel("Proportion of frames")

        if ylim is not None:
            ax.set_ylim(ylim)

        if save is not None:  # pragma: no cover
            plt.savefig(save)

        return_list.append(fig)

    if stat_tests:
        stat_dict = {}
        for i in combinations(treatment_dict.keys(), 2):
            # Solves issue with automatically generated examples: skip the test when
            # the samples are identical or degenerate (zero variance in either group)
            if np.any(
                np.array(
                    [
                        beh_dict[i[0]] == beh_dict[i[1]],
                        np.var(beh_dict[i[0]]) == 0,
                        np.var(beh_dict[i[1]]) == 0,
                    ]
                )
            ):
                stat_dict[i] = "Identical sources. Couldn't run"
            else:
                stat_dict[i] = stats.mannwhitneyu(
                    beh_dict[i[0]], beh_dict[i[1]], alternative="two-sided"
                )
        return_list.append(stat_dict)

    return return_list


def max_behaviour(
    behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.array:
    """Returns the most frequent behaviour in a window of window_size frames

    Parameters:
            - behaviour_dframe (pd.DataFrame): boolean matrix containing occurrence
            of tagged behaviours per frame in the video
            - window_size (int): size of the window to use when computing
            the maximum behaviour per time slot
            - stepped (bool): sliding windows don't overlap if True. False by default

    Returns:
        - max_array (np.array): string array with the most common behaviour per instance
        of the sliding window"""

    # Speed columns hold continuous traits, not behaviour flags: leave them out
    speed_cols = [col for col in behaviour_dframe.columns if "speed" in col.lower()]

    # Per-column counts of positive tags inside each centered window
    tag_counts = (
        behaviour_dframe.drop(speed_cols, axis=1)
        .astype("float")
        .rolling(window_size, center=True)
        .sum()
    )

    if stepped:
        # Keep only non-overlapping windows
        tag_counts = tag_counts[::window_size]

    # Drop the first row (the centered window is incomplete there) and pick
    # the behaviour with the highest count in each remaining window
    most_common = tag_counts[1:].idxmax(axis=1)

    return np.array(most_common)


def get_hparameters(hparams: dict = None) -> dict:
    """Returns the default hyperparameters of the rule-based annotation pipeline,
    overwritten with any user-provided values

    Parameters:
        - hparams (dict): dictionary containing hyperparameters to overwrite

    Returns:
        - defaults (dict): dictionary with overwritten parameters. Those not
        specified in the input retain their default values"""

    defaults = {
        "speed_pause": 3,  # rolling window size (frames) used when computing speeds
        "climb_tol": 10,  # tolerance (in pixels) used when detecting wall climbing
        "close_contact_tol": 35,  # maximum distance between single body parts in a contact
        "side_contact_tol": 80,  # maximum distance between body part pairs in a side contact
        "follow_frames": 10,  # number of frames over which following is tracked
        "follow_tol": 5,  # maximum distance from the followed animal's recent path
        "huddle_forward": 15,  # maximum ear-to-forward-limb distance in a huddle
        "huddle_spine": 10,  # maximum average spine segment distance in a huddle
        "huddle_speed": 0.1,  # maximum centre speed in a huddle
        "fps": 24,  # frames per second of the annotated videos
    }

    # BUG FIX: the default used to be a mutable `{}` and the docstring was a
    # copy-paste from max_behaviour; default to None and merge explicitly
    defaults.update(hparams or {})

    return defaults


def frame_corners(w, h, corners: dict = None):
    """Returns a dictionary with the corner positions of the video frame

    Parameters:
        - w (int): width of the frame in pixels
        - h (int): height of the frame in pixels
        - corners (dict): dictionary containing corners to overwrite

    Returns:
        - defaults (dict): dictionary with overwritten parameters. Those not
        specified in the input retain their default values"""

    # Pixel anchors for on-screen annotations, derived from the frame size
    defaults = {
        "downleft": (int(w * 0.3 / 10), int(h / 1.05)),
        "downright": (int(w * 6.5 / 10), int(h / 1.05)),
        "upleft": (int(w * 0.3 / 10), int(h / 20)),
        "upright": (int(w * 6.3 / 10), int(h / 20)),
    }

    # Avoid the mutable-default-argument pitfall: default to a fresh empty dict
    defaults.update(corners or {})

    return defaults


451
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
    tracks: List,
    videos: List,
    coordinates: Coordinates,
    coords: Any,
    dists: Any,
    speeds: Any,
    vid_index: int,
    arena_type: str,
    recog_limit: int = 1,
    path: str = os.path.join("."),
    hparams: dict = {},
) -> pd.DataFrame:
    """Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
    video displaying the information in real time

    Parameters:
        - tracks (list): list containing experiment IDs as strings
        - videos (list): list of videos to load, in the same order as tracks
        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
        - coords (deepof.preprocessing.table_dict): table_dict with already processed coordinates
        - dists (deepof.preprocessing.table_dict): table_dict with already processed distances
        - speeds (deepof.preprocessing.table_dict): table_dict with already processed speeds
        - vid_index (int): index in videos of the experiment to annotate
        - arena_type (str): arena type; must be one of ['circular']
        - path (str): directory in which the experimental data is stored
        - recog_limit (int): number of frames to use for arena recognition (1 by default)
        - hparams (dict): dictionary to overwrite the default values of the hyperparameters of the functions
        that the rule-based pose estimation utilizes. Values can be:
            - speed_pause (int): size of the rolling window to use when computing speeds
            - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - follow_frames (int): number of frames during which the following trait is tracked
            - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
            in order to report a detection
            - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
            - huddle_spine (int): maximum average distance between spine body parts to report a huddle detection
            - huddle_speed (int): maximum speed to report a huddle detection

    Returns:
        - tag_df (pandas.DataFrame): table with traits as columns and frames as rows. Each
        value is a boolean indicating trait detection at a given time"""

    # Fill in any hyperparameters not explicitly overridden by the user
    hparams = get_hparameters(hparams)
    # IDs of the tracked animals, as stored in the project (protected member)
    animal_ids = coordinates._animal_ids
    # Separator between animal ID and body part name in multi-animal projects
    undercond = "_" if len(animal_ids) > 1 else ""

    # Strip the DLC model suffix from the experiment name, if present
    try:
        vid_name = re.findall("(.*)DLC", tracks[vid_index])[0]
    except IndexError:
        vid_name = tracks[vid_index]

    # Select the tables corresponding to the requested experiment
    coords = coords[vid_name]
    dists = dists[vid_name]
    speeds = speeds[vid_name]
    # Absolute arena diameter in mm (used to rescale pixel distances);
    # NOTE(review): assumes get_arenas[1][0] holds that value — confirm upstream
    arena_abs = coordinates.get_arenas[1][0]
    # Detected arena parameters ((center, axes, angle) for circular arenas,
    # as consumed by climb_wall) plus the video frame dimensions
    arena, h, w = deepof.utils.recognize_arena(
        videos, vid_index, path, recog_limit, coordinates._arena
    )

    # Dictionary with motives per frame
    tag_dict = {}

    def onebyone_contact(bparts: List):
        """Returns a smooth boolean array with 1to1 contacts between two mice"""
        nonlocal coords, animal_ids, hparams, arena_abs, arena

        # bparts[0] belongs to the first animal, bparts[-1] to the second;
        # arena[1][1] is the arena's pixel-scale reference axis
        return deepof.utils.smooth_boolean_array(
            close_single_contact(
                coords,
                animal_ids[0] + bparts[0],
                animal_ids[1] + bparts[-1],
                hparams["close_contact_tol"],
                arena_abs,
                arena[1][1],
            )
        )

    def twobytwo_contact(rev):
        """Returns a smooth boolean array with side by side contacts between two mice"""

        nonlocal coords, animal_ids, hparams, arena_abs, arena
        return deepof.utils.smooth_boolean_array(
            close_double_contact(
                coords,
                animal_ids[0] + "_Nose",
                animal_ids[0] + "_Tail_base",
                animal_ids[1] + "_Nose",
                animal_ids[1] + "_Tail_base",
                hparams["side_contact_tol"],
                rev=rev,
                arena_abs=arena_abs,
                arena_rel=arena[1][1],
            )
        )

    # Dyadic traits are only defined for two-animal experiments
    if len(animal_ids) == 2:
        # Define behaviours that can be computed on the fly from the distance matrix
        tag_dict["nose2nose"] = onebyone_contact(bparts=["_Nose"])

        tag_dict["sidebyside"] = twobytwo_contact(rev=False)

        tag_dict["sidereside"] = twobytwo_contact(rev=True)

        # Asymmetric nose-to-tail contacts, one per animal ordering
        tag_dict[animal_ids[0] + "_nose2tail"] = onebyone_contact(
            bparts=["_Nose", "_Tail_base"]
        )
        tag_dict[animal_ids[1] + "_nose2tail"] = onebyone_contact(
            bparts=["_Tail_base", "_Nose"]
        )

        # Each animal can be the follower of the other
        for _id in animal_ids:
            tag_dict[_id + "_following"] = deepof.utils.smooth_boolean_array(
                following_path(
                    dists,
                    coords,
                    follower=_id,
                    followed=[i for i in animal_ids if i != _id][0],
                    frames=hparams["follow_frames"],
                    tol=hparams["follow_tol"],
                )
            )

    # Per-animal traits, computed regardless of the number of animals
    for _id in animal_ids:
        tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
            climb_wall(
                arena_type,
                arena,
                coords,
                hparams["climb_tol"],
                _id + undercond + "Nose",
            )
        )
        # Speed is stored as a continuous trait (not smoothed or thresholded)
        tag_dict[_id + undercond + "speed"] = speeds[_id + undercond + "Center"]
        tag_dict[_id + undercond + "huddle"] = deepof.utils.smooth_boolean_array(
            huddle(
                coords,
                speeds,
                hparams["huddle_forward"],
                hparams["huddle_spine"],
                hparams["huddle_speed"],
                animal_id=_id,
            )
        )

    # Column order of tag_df follows the insertion order of tag_dict above
    tag_df = pd.DataFrame(tag_dict)

    return tag_df


def tag_rulebased_frames(
    frame,
    font,
    frame_speeds,
    animal_ids,
    corners,
    tag_dict,
    fnum,
    dims,
    undercond,
    hparams,
    arena,
    debug,
    coords,
):
    """Helper function for rule_based_video. Annotates a given frame with on-screen information
    about the recognised patterns

    Parameters:
        - frame: video frame (cv2 image) to annotate in place
        - font: cv2 font to use for all text overlays
        - frame_speeds: per-animal speeds for this frame (dict keyed by animal ID
        in multi-animal experiments; a single scalar otherwise)
        - animal_ids: list of animal ID prefixes
        - corners: anchor positions for on-screen text (see frame_corners)
        - tag_dict: per-trait boolean arrays, as produced by rule_based_tagging
        - fnum: index of the current frame
        - dims: (width, height) of the frame
        - undercond: separator between animal ID and body part names ("_" or "")
        - hparams: hyperparameter dict (see get_hparameters)
        - arena: (arena_params, h, w) tuple as returned by arena recognition
        - debug: if True, draws the detected arena and body part positions
        - coords: body part coordinates table for this experiment"""

    w, h = dims
    # NOTE(review): this immediately overwrites h and w unpacked from dims
    # above with the frame dimensions bundled in `arena` — confirm `dims`
    # is still needed as a parameter
    arena, h, w = arena

    def write_on_frame(text, pos, col=(255, 255, 255)):
        """Partial closure over cv2.putText to avoid code repetition"""
        return cv2.putText(frame, text, pos, font, 1, col, 2)

    def conditional_pos():
        """Returns a position depending on a condition"""
        # The faster animal's annotation goes on the left
        if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]:
            return corners["downleft"]
        else:
            return corners["downright"]

    def conditional_col(cond=None):
        """Returns a colour depending on a condition"""
        if cond is None:
            cond = frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
        # BGR: reddish when the condition holds, greenish otherwise
        if cond:
            return 150, 150, 255
        else:
            return 150, 255, 150

    # (animal ID, lower text anchor, upper text anchor) per animal
    zipped_pos = list(
        zip(
            animal_ids,
            [corners["downleft"], corners["downright"]],
            [corners["upleft"], corners["upright"]],
        )
    )

    # Dyadic traits are only annotated in multi-animal experiments
    if len(animal_ids) > 1:

        if debug:
            # Print arena for debugging
            cv2.ellipse(frame, arena[0], arena[1], arena[2], 0, 360, (0, 255, 0), 3)

            # Print body parts for debugging, colour-coded per animal
            for bpart in coords.columns.levels[0]:
                cv2.circle(
                    frame,
                    (int(coords[bpart]["x"][fnum]), int(coords[bpart]["y"][fnum])),
                    radius=3,
                    color=(
                        (255, 0, 0) if bpart.startswith(animal_ids[0]) else (0, 0, 255)
                    ),
                    thickness=-1,
                )

            # Print frame number
            write_on_frame("Frame " + str(fnum), corners["downleft"])

        # Side-by-side contacts take precedence over single-point contacts,
        # which are suppressed when the broader trait is active
        if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
            write_on_frame("Nose-Nose", conditional_pos())
        if (
            tag_dict[animal_ids[0] + "_nose2tail"][fnum]
            and not tag_dict["sidereside"][fnum]
        ):
            write_on_frame("Nose-Tail", corners["downleft"])
        if (
            tag_dict[animal_ids[1] + "_nose2tail"][fnum]
            and not tag_dict["sidereside"][fnum]
        ):
            write_on_frame("Nose-Tail", corners["downright"])
        if tag_dict["sidebyside"][fnum]:
            write_on_frame(
                "Side-side",
                conditional_pos(),
            )
        if tag_dict["sidereside"][fnum]:
            write_on_frame(
                "Side-Rside",
                conditional_pos(),
            )
        # for _id, down_pos, up_pos in zipped_pos:
        #     if (
        #         tag_dict[_id + "_following"][fnum]
        #         and not tag_dict[_id + "_climbing"][fnum]
        #     ):
        #         write_on_frame(
        #             "*f",
        #             (int(w * 0.3 / 10), int(h / 10)),
        #             conditional_col(),
        #         )

    # Per-animal annotations (climbing, huddling and speed)
    for _id, down_pos, up_pos in zipped_pos:

        if tag_dict[_id + undercond + "climbing"][fnum]:
            write_on_frame("Climbing", down_pos)
        # Huddling is suppressed while the animal is climbing
        if (
            tag_dict[_id + undercond + "huddle"][fnum]
            and not tag_dict[_id + undercond + "climbing"][fnum]
        ):
            write_on_frame("huddle", down_pos)

        # Define the condition controlling the colour of the speed display
        if len(animal_ids) > 1:
            colcond = frame_speeds[_id] == max(list(frame_speeds.values()))
        else:
            # Single-animal: frame_speeds is a scalar here, highlighted when
            # below the huddle speed threshold
            colcond = hparams["huddle_speed"] > frame_speeds

        # Display the animal's speed (mm per frame), rounded to 2 decimals
        write_on_frame(
            str(
                np.round(
                    (frame_speeds if len(animal_ids) == 1 else frame_speeds[_id]), 2
                )
            )
            + " mmpf",
            up_pos,
            conditional_col(cond=colcond),
        )


# noinspection PyProtectedMember,PyDefaultArgument
def rule_based_video(
lucas_miranda's avatar
lucas_miranda committed
732
733
734
735
736
737
738
739
740
    coordinates: Coordinates,
    tracks: List,
    videos: List,
    vid_index: int,
    tag_dict: pd.DataFrame,
    frame_limit: int = np.inf,
    recog_limit: int = 1,
    path: str = os.path.join("."),
    hparams: dict = {},
741
    debug: bool = False,
lucas_miranda's avatar
lucas_miranda committed
742
) -> True:
743
744
745
746
747
748
    """Renders a version of the input video with all rule-based taggings in place.

    Parameters:
        - tracks (list): list containing experiment IDs as strings
        - videos (list): list of videos to load, in the same order as tracks
        - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
749
750
        - debug (bool): if True, several debugging attributes (such as used body parts and arena) are plotted in
        the output video
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
        - vid_index (int): index in videos of the experiment to annotate
        - fps (float): frames per second of the analysed video. Same as input by default
        - path (str): directory in which the experimental data is stored
        - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
        - recog_limit (int): number of frames to use for arena recognition (1 by default)
        - hparams (dict): dictionary to overwrite the default values of the hyperparameters of the functions
        that the rule-based pose estimation utilizes. Values can be:
            - speed_pause (int): size of the rolling window to use when computing speeds
            - close_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - side_contact_tol (int): maximum distance between single bodyparts that can be used to report the trait
            - follow_frames (int): number of frames during which the following trait is tracked
            - follow_tol (int): maximum distance between follower and followed's path during the last follow_frames,
            in order to report a detection
            - huddle_forward (int): maximum distance between ears and forward limbs to report a huddle detection
            - huddle_spine (int): maximum average distance between spine body parts to report a huddle detection
            - huddle_speed (int): maximum speed to report a huddle detection

    Returns:
        True

    """

lucas_miranda's avatar
lucas_miranda committed
773
    # DATA OBTENTION AND PREPARATION
lucas_miranda's avatar
lucas_miranda committed
774
    hparams = get_hparameters(hparams)
775
    animal_ids = coordinates._animal_ids
lucas_miranda's avatar
lucas_miranda committed
776
    undercond = "_" if len(animal_ids) > 1 else ""
777

778
    try:
779
        vid_name = re.findall("(.*)DLC", tracks[vid_index])[0]
780
781
    except IndexError:
        vid_name = tracks[vid_index]
782
783
784
785

    arena, h, w = deepof.utils.recognize_arena(
        videos, vid_index, path, recog_limit, coordinates._arena
    )
786
    corners = frame_corners(h, w)
787

lucas_miranda's avatar
lucas_miranda committed
788
789
790
791
    cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
    # Keep track of the frame number, to align with the tracking data
    fnum = 0
    writer = None
lucas_miranda's avatar
lucas_miranda committed
792
793
794
    frame_speeds = (
        {_id: -np.inf for _id in animal_ids} if len(animal_ids) > 1 else -np.inf
    )
795

lucas_miranda's avatar
lucas_miranda committed
796
797
    # Loop over the frames in the video
    while cap.isOpened() and fnum < frame_limit:
798

lucas_miranda's avatar
lucas_miranda committed
799
800
801
802
803
        ret, frame = cap.read()
        # if frame is read correctly ret is True
        if not ret:  # pragma: no cover
            print("Can't receive frame (stream end?). Exiting ...")
            break
804

lucas_miranda's avatar
lucas_miranda committed
805
        font = cv2.FONT_HERSHEY_COMPLEX_SMALL
806

lucas_miranda's avatar
lucas_miranda committed
807
808
809
        # Capture speeds
        try:
            if (
lucas_miranda's avatar
lucas_miranda committed
810
811
                list(frame_speeds.values())[0] == -np.inf
                or fnum % hparams["speed_pause"] == 0
lucas_miranda's avatar
lucas_miranda committed
812
813
            ):
                for _id in animal_ids:
814
                    frame_speeds[_id] = tag_dict[_id + undercond + "speed"][fnum]
lucas_miranda's avatar
lucas_miranda committed
815
816
        except AttributeError:
            if frame_speeds == -np.inf or fnum % hparams["speed_pause"] == 0:
817
                frame_speeds = tag_dict["speed"][fnum]
lucas_miranda's avatar
lucas_miranda committed
818
819

        # Display all annotations in the output video
lucas_miranda's avatar
lucas_miranda committed
820
821
822
823
        tag_rulebased_frames(
            frame,
            font,
            frame_speeds,
lucas_miranda's avatar
lucas_miranda committed
824
            animal_ids,
lucas_miranda's avatar
lucas_miranda committed
825
826
827
            corners,
            tag_dict,
            fnum,
lucas_miranda's avatar
lucas_miranda committed
828
            (w, h),
lucas_miranda's avatar
lucas_miranda committed
829
830
            undercond,
            hparams,
831
            (arena, h, w),
832
833
            debug,
            coordinates.get_coords(center=False)[vid_name],
lucas_miranda's avatar
lucas_miranda committed
834
835
        )

lucas_miranda's avatar
lucas_miranda committed
836
837
838
839
840
        if writer is None:
            # Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file.
            # Define the FPS. Also frame size is passed.
            writer = cv2.VideoWriter()
            writer.open(
841
                vid_name + "_tagged.avi",
lucas_miranda's avatar
lucas_miranda committed
842
843
844
845
846
                cv2.VideoWriter_fourcc(*"MJPG"),
                hparams["fps"],
                (frame.shape[1], frame.shape[0]),
                True,
            )
847

lucas_miranda's avatar
lucas_miranda committed
848
        writer.write(frame)
lucas_miranda's avatar
lucas_miranda committed
849
        fnum += 1
850

lucas_miranda's avatar
lucas_miranda committed
851
852
    cap.release()
    cv2.destroyAllWindows()
lucas_miranda's avatar
lucas_miranda committed
853
854

    return True

# TODO:
#    - Is border sniffing anything you might consider interesting?