# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations to get the data into the desired shape.
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.
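
Minimal usage sketch (hypothetical paths and parameter values; assumes a DLC project
folder containing 'Videos' and 'Tables' subdirectories, as expected by project below):

    my_project = project(path="./my_dlc_project", arena_dims=(380,))
    my_coordinates = my_project.run(verbose=True)               # returns a coordinates object
    position_dict = my_coordinates.get_coords(center="arena")   # a table_dict of positions
    preprocessed = position_dict.preprocess(window_size=10, test_videos=1)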

"""

from collections import defaultdict
from joblib import delayed, Parallel, parallel_backend
from typing import Dict, List, Tuple, Union
from multiprocessing import cpu_count
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm

import deepof.pose_utils
import deepof.utils
import deepof.visuals
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import warnings

# DEFINE CUSTOM ANNOTATED TYPES #

Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)

# CLASSES FOR PREPROCESSING AND DATA WRANGLING


class project:
    """

    Class for loading and preprocessing DLC data of individual and multiple animals. All main computations are called
    here.

    """

    def __init__(
        self,
        animal_ids: List = tuple([""]),
        arena: str = "circular",
        arena_dims: tuple = (1,),
        exclude_bodyparts: List = tuple([""]),
        exp_conditions: dict = None,
        interpolate_outliers: bool = True,
        interpolation_limit: int = 5,
        interpolation_std: int = 5,
        likelihood_tol: float = 0.75,
        model: str = "mouse_topview",
        path: str = deepof.utils.os.path.join("."),
        smooth_alpha: float = 0.99,
        table_format: str = "autodetect",
        video_format: str = ".mp4",
    ):

        self.path = path
        self.video_path = os.path.join(self.path, "Videos")
        self.table_path = os.path.join(self.path, "Tables")

        self.table_format = table_format
        if self.table_format == "autodetect":
            ex = [i for i in os.listdir(self.table_path) if not i.startswith(".")][0]
            if ".h5" in ex:
                self.table_format = ".h5"
            elif ".csv" in ex:
                self.table_format = ".csv"

        self.videos = sorted(
            [
                vid
                for vid in deepof.utils.os.listdir(self.video_path)
                if vid.endswith(video_format) and not vid.startswith(".")
            ]
        )
        self.tables = sorted(
            [
                tab
                for tab in deepof.utils.os.listdir(self.table_path)
                if tab.endswith(self.table_format) and not tab.startswith(".")
            ]
        )
        self.angles = True
        self.animal_ids = animal_ids
        self.arena = arena
        self.arena_dims = arena_dims
        self.distances = "all"
        self.ego = False
        self.exp_conditions = exp_conditions
        self.interpolate_outliers = interpolate_outliers
        self.interpolation_limit = interpolation_limit
        self.interpolation_std = interpolation_std
        self.likelihood_tolerance = likelihood_tol
        self.scales = self.get_scale
        self.smooth_alpha = smooth_alpha
        self.subset_condition = None
        self.video_format = video_format

        model_dict = {
            "mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
        }
        self.connectivity = model_dict[model]
        self.exclude_bodyparts = exclude_bodyparts
        if self.exclude_bodyparts != tuple([""]):
            for bp in exclude_bodyparts:
                self.connectivity.remove_node(bp)

    def __str__(self):
        if self.exp_conditions:
            return "deepof analysis of {} videos across {} conditions".format(
                len(self.videos), len(set(self.exp_conditions.values()))
            )
        else:
            return "deepof analysis of {} videos".format(len(self.videos))

    @property
    def subset_condition(self):
        """Sets a subset condition for the videos to load. If set,
        only the videos with the included pattern will be loaded"""
        return self._subset_condition

    @property
    def distances(self):
        """List. If not 'all', sets the body parts among which the
        distances will be computed"""
        return self._distances

    @property
    def ego(self):
        """String, name of a body part. If set, computes only the distances
        between the specified body part and the rest"""
        return self._ego

    @property
    def angles(self):
        """Bool. Toggles angle computation. True by default. If turned off,
        enhances performance for big datasets"""
        return self._angles

    @property
    def get_scale(self) -> np.array:
        """Returns the arena as recognised from the videos"""

        if self.arena in ["circular"]:

            scales = []
            for vid_index, _ in enumerate(self.videos):

                ellipse = deepof.utils.recognize_arena(
                    self.videos,
                    vid_index,
                    path=self.video_path,
                    arena_type=self.arena,
                )[0]

                scales.append(
                    list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1]]) * 2)
                    + list(self.arena_dims)
                )

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales)

    def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
        """Loads videos and tables into dictionaries"""

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        tab_dict = {}
        if self.table_format == ".h5":

            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_hdf(
                    deepof.utils.os.path.join(self.table_path, tab), dtype=float
                )
                for tab in self.tables
            }
        elif self.table_format == ".csv":

            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_csv(
                    deepof.utils.os.path.join(self.table_path, tab),
                    header=[0, 1, 2],
                    index_col=0,
                    dtype=float,
                )
                for tab in self.tables
            }

        lik_dict = defaultdict()

        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik = value.xs("likelihood", level="coords", axis=1, drop_level=True)

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik.droplevel("scorer", axis=1)

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    deepof.utils.smooth_mult_trajectory(
                        np.array(tab), alpha=self.smooth_alpha
                    )
                )
                smooth.columns = cols
                tab_dict[key] = smooth.iloc[1:, :].reset_index(drop=True)

        for key, tab in tab_dict.items():
            tab_dict[key] = tab[tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in tab_dict.items():
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                tab_dict[key] = tab

        if self.exclude_bodyparts != tuple([""]):

            for k, value in tab_dict.items():
                temp = value.drop(self.exclude_bodyparts, axis=1, level="bodyparts")
                temp.sort_index(axis=1, inplace=True)
                temp.columns = pd.MultiIndex.from_product(
                    [sorted(list(set([i[j] for i in temp.columns]))) for j in range(2)]
                )
                tab_dict[k] = temp.sort_index(axis=1)

        if self.interpolate_outliers:

            if verbose:
                print("Interpolating outliers...")

            for k, value in tab_dict.items():
                tab_dict[k] = deepof.utils.interpolate_outliers(
                    value,
                    lik_dict[k],
                    likelihood_tolerance=self.likelihood_tolerance,
                    mode="or",
                    limit=self.interpolation_limit,
                    n_std=self.interpolation_std,
                )

        return tab_dict, lik_dict

    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """Computes the distances between all selected body parts over time.
288
        If ego is provided, it only returns distances to a specified bodypart"""
289

lucas_miranda's avatar
lucas_miranda committed
290
291
292
        if verbose:
            print("Computing distances...")

293
        nodes = self.distances
294
        if nodes == "all":
lucas_miranda's avatar
lucas_miranda committed
295
            nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]
296
297

        assert [
lucas_miranda's avatar
lucas_miranda committed
298
            i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
299
300
301
302
        ], "Nodes should correspond to existent bodyparts"

        scales = self.scales[:, 2:]

303
        distance_dict = {
304
305
306
307
308
            key: deepof.utils.bpart_distance(
                tab,
                scales[i, 1],
                scales[i, 0],
            )
lucas_miranda's avatar
lucas_miranda committed
309
            for i, (key, tab) in enumerate(tab_dict.items())
310
        }
311

lucas_miranda's avatar
lucas_miranda committed
312
313
        for key in distance_dict.keys():
            distance_dict[key] = distance_dict[key].loc[
314
315
                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
            ]
lucas_miranda's avatar
lucas_miranda committed
316

317
318
319
        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
320
321
                    :, [dist for dist in val.columns if self.ego in dist]
                ]
322
323
324

        return distance_dict

    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """

        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        cliques = deepof.utils.nx.enumerate_all_cliques(self.connectivity)
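        # Keep only three-node cliques, i.e. connected bodypart trios that define an angle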
        cliques = [i for i in cliques if len(i) == 3]

        angle_dict = {}
        for key, tab in tab_dict.items():

            dats = []
            for clique in cliques:
                dat = pd.DataFrame(
                    deepof.utils.angle_trio(
                        np.array(tab[clique]).reshape([3, tab.shape[0], 2])
                    )
                ).T

                orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]
                dat.columns = [tuple(clique[i] for i in order) for order in orders]

                dats.append(dat)

            dats = pd.concat(dats, axis=1)

            angle_dict[key] = dats

        return angle_dict

    def run(self, verbose: bool = True) -> Coordinates:
        """Generates a dataset using all the options specified during initialization"""

        tables, quality = self.load_tables(verbose)
        distances = None
        angles = None

        if self.distances:
            distances = self.get_distances(tables, verbose)

        if self.angles:
            angles = self.get_angles(tables, verbose)

        if verbose:
            print("Done!")

        return coordinates(
            angles=angles,
            animal_ids=self.animal_ids,
            arena=self.arena,
            arena_dims=self.arena_dims,
            distances=distances,
            exp_conditions=self.exp_conditions,
            path=self.path,
            quality=quality,
            scales=self.scales,
            tables=tables,
            videos=self.videos,
        )

    @subset_condition.setter
    def subset_condition(self, value):
        self._subset_condition = value

    @distances.setter
    def distances(self, value):
        self._distances = value

    @ego.setter
    def ego(self, value):
        self._ego = value

    @angles.setter
    def angles(self, value):
        self._angles = value


class coordinates:
    """

    Class for storing the results of running a project. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

    def __init__(
        self,
        arena: str,
        arena_dims: np.array,
        path: str,
        quality: dict,
        scales: np.array,
        tables: dict,
        videos: list,
        angles: dict = None,
        animal_ids: List = tuple([""]),
        distances: dict = None,
        exp_conditions: dict = None,
    ):
        self._animal_ids = animal_ids
        self._arena = arena
        self._arena_dims = arena_dims
        self._exp_conditions = exp_conditions
        self._path = path
        self._quality = quality
        self._scales = scales
        self._tables = tables
        self._videos = videos
        self.angles = angles
        self.distances = distances

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
                len(self._videos), len(set(self._exp_conditions.values()))
            )
        else:
            return "deepof analysis of {} videos".format(len(self._videos))

    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
        align_inplace: bool = False,
        propagate_labels: bool = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If false, the raw data is returned; if 'arena' (default), coordinates are
                centered in the arena
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (bool): selects the body part to which later processes will align the frames with
                (see preprocess in table_dict documentation).
                - align_inplace (bool): Only valid if align is set. Aligns the vector that goes from the origin to
                the selected body part with the y axis, for all time points.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
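
            Example (a hedged sketch; body part names are hypothetical and depend on the
            tracking model in use):

                position_dict = my_coordinates.get_coords(center="Center", align="Nose", align_inplace=True)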
        """

        tabs = deepof.utils.deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = deepof.utils.tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":

                for i, (key, value) in enumerate(tabs.items()):

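                    # Cartesian data carries x/y columns, while polar data (polar=True)
                    # carries rho/phi; the KeyError fallback below handles the polar case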
                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif type(center) == str and center != "arena":

            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)

                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = deepof.utils.rolling_speed(tab, deriv=speed, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype("timedelta64[s]")

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns

                tab = tab[columns]
                tabs[key] = tab

                if align_inplace and polar is False:
                    index = tab.columns
                    tab = pd.DataFrame(
                        deepof.utils.align_trajectories(np.array(tab), mode="all")
                    )
                    tab.columns = index
                    tabs[key] = tab

        if propagate_labels:
            for key, tab in tabs.items():
                tab["pheno"] = self._exp_conditions[key]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
            propagate_labels=propagate_labels,
        )

    def get_distances(
        self, speed: int = 0, length: str = None, propagate_labels: bool = False
    ) -> Table_dict:
        """
        Returns a table_dict object with the distances between body parts of each animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self.distances)

        if self.distances is not None:

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="dists")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab["pheno"] = self._exp_conditions[key]

            return table_dict(tabs, propagate_labels=propagate_labels, typ="dists")

        raise ValueError(
            "Distances not computed. Read the documentation for more details"
        )  # pragma: no cover

    def get_angles(
        self,
        degrees: bool = False,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts of each animal as values.

            Parameters:
                - degrees (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self.angles)

        if self.angles is not None:

            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="angles")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab["pheno"] = self._exp_conditions[key]

            return table_dict(tabs, propagate_labels=propagate_labels, typ="angles")

        raise ValueError(
            "Angles not computed. Read the documentation for more details"
        )  # pragma: no cover

    def get_videos(self, play: bool = False):
        """Returns the videos associated with the dataset as a list."""

        if play:  # pragma: no cover
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
        """Returns the stored dictionary with experimental conditions per subject"""

        return self._exp_conditions

    def get_quality(self):
        """Retrieves a dictionary with the tagging quality per video, as reported by DLC"""

        return self._quality

    @property
    def get_arenas(self):
        """Retrieves all available information associated with the arena"""

        return self._arena, self._arena_dims, self._scales

    # noinspection PyDefaultArgument
    def rule_based_annotation(
        self,
        params: Dict = {},
        video_output: bool = False,
        frame_limit: int = np.inf,
        debug: bool = False,
    ) -> Table_dict:
        """Annotates coordinates using a simple rule-based pipeline"""

        tag_dict = {}
        # noinspection PyTypeChecker
        coords = self.get_coords(center=False)
        dists = self.get_distances()
        speeds = self.get_coords(speed=1)

        for key in tqdm(self._tables.keys()):

            video = [vid for vid in self._videos if key + "DLC" in vid][0]
            tag_dict[key] = deepof.pose_utils.rule_based_tagging(
                list(self._tables.keys()),
                self._videos,
                self,
                coords,
                dists,
                speeds,
                self._videos.index(video),
                arena_type=self._arena,
                recog_limit=1,
                path=os.path.join(self._path, "Videos"),
                hparams=params,
            )

        if video_output:  # pragma: no cover

            def output_video(idx):
                """Outputs a single annotated video. Enclosed in a function to enable parallelization"""

                deepof.pose_utils.rule_based_video(
                    self,
                    list(self._tables.keys()),
                    self._videos,
                    list(self._tables.keys()).index(idx),
                    tag_dict[idx],
                    debug=debug,
                    frame_limit=frame_limit,
                    recog_limit=1,
                    path=os.path.join(self._path, "Videos"),
                    hparams=params,
                )
                pbar.update(1)

            if type(video_output) == list:
                vid_idxs = video_output
            elif video_output == "all":
                vid_idxs = list(self._tables.keys())
            else:
                raise AttributeError(
                    "Video output must be either 'all' or a list with the names of the videos to render"
                )

            njobs = cpu_count() // 2
            pbar = tqdm(total=len(vid_idxs))
            with parallel_backend("threading", n_jobs=njobs):
                Parallel()(delayed(output_video)(key) for key in vid_idxs)
            pbar.close()

        return table_dict(
            tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
        )

    def gmvae_embedding(self):
        pass


class table_dict(dict):
    """

    Main class for storing a single dataset as a dictionary with individuals as keys and pandas.DataFrames as values.
    Includes methods for generating training and testing datasets for the autoencoders.

    """

    def __init__(
        self,
        tabs: Dict,
        typ: str,
        arena: str = None,
        arena_dims: np.array = None,
        center: str = None,
        polar: bool = None,
        propagate_labels: bool = False,
    ):
        super().__init__(tabs)
        self._type = typ
        self._center = center
        self._polar = polar
        self._arena = arena
        self._arena_dims = arena_dims
        self._propagate_labels = propagate_labels

    def filter_videos(self, keys: list) -> Table_dict:
        """Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
        for selecting data coming from videos of a specified condition."""

        assert np.all([k in self.keys() for k in keys]), "Invalid keys selected"

        return table_dict(
            {k: value for k, value in self.items() if k in keys}, self._type
        )

    # noinspection PyTypeChecker
    def plot_heatmaps(
        self,
        bodyparts: list,
        xlim: float = None,
        ylim: float = None,
        save: bool = False,
        i: int = 0,
        dpi: int = 100,
    ) -> plt.figure:
        """Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""

        if self._type != "coords" or self._polar:
            raise NotImplementedError(
                "Heatmaps only available for cartesian coordinates. "
                "Set polar to False in get_coords and try again"
            )  # pragma: no cover

        if not self._center:  # pragma: no cover
            warnings.warn("Heatmaps look better if you center the data")

        if self._arena == "circular":

            heatmaps = deepof.visuals.plot_heatmap(
                list(self.values())[i],
                bodyparts,
                xlim=xlim,
                ylim=ylim,
                save=save,
                dpi=dpi,
            )

            return heatmaps

    def get_training_set(
        self,
        test_videos: int = 0,
        encode_labels: bool = True,
    ) -> Tuple[np.ndarray, list, Union[np.ndarray, list], list]:
        """Generates training and test sets as numpy.array objects for model training"""

        # Padding of videos with slightly different lengths
        raw_data = np.array([np.array(v) for v in self.values()], dtype=object)
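        # If labels are propagated, test videos are sampled separately for each phenotype
        # label (a stratified split); otherwise they are drawn at random across all videos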
        if self._propagate_labels:
            concat_raw = np.concatenate(raw_data, axis=0)
            test_index = np.array([], dtype=int)
            for label in set(list(concat_raw[:, -1])):
                label_index = np.random.choice(
                    [i for i in range(len(raw_data)) if raw_data[i][0, -1] == label],
                    test_videos,
                    replace=False,
                )
                test_index = np.concatenate([test_index, label_index])
        else:
            test_index = np.random.choice(
                range(len(raw_data)), test_videos, replace=False
            )

        y_train, X_test, y_test = [], [], []
        if test_videos > 0:
            X_test = np.concatenate(raw_data[test_index])
            X_train = np.concatenate(np.delete(raw_data, test_index, axis=0))

        else:
            X_train = np.concatenate(list(raw_data))

        if self._propagate_labels:
            X_train, y_train = X_train[:, :-1], X_train[:, -1]
            try:
                X_test, y_test = X_test[:, :-1], X_test[:, -1]
            except TypeError:
                pass

        if encode_labels:
            le = LabelEncoder()
            y_train = le.fit_transform(y_train)
            y_test = le.transform(y_test)

        return X_train, y_train, X_test, y_test

    # noinspection PyTypeChecker,PyGlobalUndefined
    def preprocess(
        self,
        window_size: int = 1,
        window_step: int = 1,
        scale: str = "standard",
        test_videos: int = 0,
        verbose: bool = False,
        conv_filter: bool = None,
        sigma: float = 1.0,
        shift: float = 0.0,
        shuffle: bool = False,
        align: str = False,
    ) -> np.ndarray:
        """

        Main method for preprocessing the loaded dataset. Capable of returning training
        and test sets ready for model training.

            Parameters:
                - window_size (int): Size of the sliding window to pass through the data to generate training instances
                - window_step (int): Step to take when sliding the window. If 1, a true sliding window is used;
                if equal to window_size, the data is split into non-overlapping chunks.
                - scale (str): Data scaling method. Must be one of 'standard' (default; recommended) and 'minmax'.
                - test_videos (int): Number of videos to use when generating the test set.
                If 0, no test set is generated (not recommended).
                - verbose (bool): prints job information if True
                - conv_filter (bool): must be one of None, 'gaussian'. If not None, convolves each instance
                with the specified kernel.
                - sigma (float): usable only if conv_filter is 'gaussian'. Standard deviation of the kernel to use.
                - shift (float): usable only if conv_filter is 'gaussian'. Shift from mean zero of the kernel to use.
                - shuffle (bool): Shuffles the data instances if True. In most use cases, it should be True for training
                and False for prediction.
                - align (bool): If "all", rotates all data instances to align the center -> align (selected before
                when calling get_coords) axis with the y-axis of the cartesian plane. If 'center', rotates all instances
                using the angle of the central frame of the sliding window. This way rotations of the animal are caught
                as well. It doesn't do anything if False.
                - propagate_labels (bool): If True, returns a label vector accompanying each training instance

            Returns:
                - X_train (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all training videos
                - X_test (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all test videos (if test_videos > 0)
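
            Example (a hedged sketch; assumes a table_dict obtained from coordinates.get_coords
            and uses hypothetical parameter values):

                preprocessed = my_table_dict.preprocess(
                    window_size=10, window_step=5, scale="standard", test_videos=1, shuffle=True
                )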

        """

        global g
        X_train, y_train, X_test, y_test = self.get_training_set(test_videos)

        if scale:
            if verbose:
                print("Scaling data...")

            if scale == "standard":
                scaler = StandardScaler()
            elif scale == "minmax":
                scaler = MinMaxScaler()
            else:
                raise ValueError(
                    "Invalid scaler. Select one of standard, minmax or None"
                )  # pragma: no cover

            X_train = scaler.fit_transform(
                X_train.reshape(-1, X_train.shape[-1])
            ).reshape(X_train.shape)

            if scale == "standard":
                assert np.allclose(np.nan_to_num(np.mean(X_train), nan=0), 0)
                assert np.allclose(np.nan_to_num(np.std(X_train), nan=1), 1)