data.py 41.7 KB
Newer Older
lucas_miranda's avatar
lucas_miranda committed
1
# @author lucasmiranda42
2
3
# encoding: utf-8
# module deepof
lucas_miranda's avatar
lucas_miranda committed
4

5
6
7
8
9
10
11
12
13
14
15
16
"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations for getting the data into the desired shape
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.

"""

17
from collections import defaultdict
18
from joblib import delayed, Parallel, parallel_backend
19
from typing import Dict, List, Tuple, Union
lucas_miranda's avatar
lucas_miranda committed
20
from multiprocessing import cpu_count
21
22
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
23
from sklearn.impute import SimpleImputer
24
from sklearn.manifold import TSNE
25
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
26
from tqdm import tqdm
27
import deepof.models
28
29
30
31
32
import deepof.pose_utils
import deepof.utils
import deepof.visuals
import matplotlib.pyplot as plt
import numpy as np
33
import os
34
35
import pandas as pd
import warnings
36

37
38
# DEFINE CUSTOM ANNOTATED TYPES #

39
40
# Aliases built with NewType over Any (both reached through deepof.utils).
# They are used further down as return annotations, e.g. project.run -> Coordinates
# and coordinates.get_coords -> Table_dict.
Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)
41
42

# CLASSES FOR PREPROCESSING AND DATA WRANGLING
43

44

45
class project:
lucas_miranda's avatar
lucas_miranda committed
46
47
    """

48
49
    Class for loading and preprocessing DLC data of individual and multiple animals. All main computations are called
    here.
lucas_miranda's avatar
lucas_miranda committed
50
51

    """
52
53

    def __init__(
        self,
        animal_ids: List = tuple([""]),
        arena: str = "circular",
        arena_dims: tuple = (1,),
        exclude_bodyparts: List = tuple([""]),
        exp_conditions: dict = None,
        interpolate_outliers: bool = True,
        interpolation_limit: int = 5,
        interpolation_std: int = 5,
        likelihood_tol: float = 0.75,
        model: str = "mouse_topview",
        path: str = deepof.utils.os.path.join("."),
        smooth_alpha: float = 0.99,
        table_format: str = "autodetect",
        video_format: str = ".mp4",
    ):
        """
        Sets up a project from a directory containing 'Videos' and 'Tables' subfolders.

            Parameters:
                - animal_ids (List): identifiers of the animals; the first one is handed to
                deepof.utils.connect_mouse_topview to build the connectivity graph
                - arena (str): arena type; get_scale only supports 'circular'
                - arena_dims (tuple): values appended to every row of the scales array
                - exclude_bodyparts (List): body parts removed from the connectivity graph
                (and dropped from the tables in load_tables). tuple([""]) means none
                - exp_conditions (dict): experimental condition per video key
                - interpolate_outliers (bool): toggles the interpolation step in load_tables
                - interpolation_limit (int): forwarded as 'limit' to deepof.utils.interpolate_outliers
                - interpolation_std (int): forwarded as 'n_std' to deepof.utils.interpolate_outliers
                - likelihood_tol (float): forwarded as 'likelihood_tolerance' to
                deepof.utils.interpolate_outliers
                - model (str): key of the connectivity model; only 'mouse_topview' is available
                - path (str): project directory
                - smooth_alpha (float): forwarded as 'alpha' to deepof.utils.smooth_mult_trajectory;
                a falsy value skips smoothing
                - table_format (str): '.h5', '.csv' or 'autodetect'
                - video_format (str): suffix a file in 'Videos' must have to be listed

            Raises:
                - IndexError: if table_format is 'autodetect' and 'Tables' has no visible file
                - KeyError: if model is not a known connectivity model
                - NotImplementedError: from get_scale, if arena is not 'circular'
        """

        self.path = path
        self.video_path = os.path.join(self.path, "Videos")
        self.table_path = os.path.join(self.path, "Tables")

        self.table_format = table_format
        if self.table_format == "autodetect":
            # The format is inferred from the first visible file only. If that file is
            # neither h5 nor csv the value stays 'autodetect' and load_tables rejects it
            visible_tables = [
                i for i in os.listdir(self.table_path) if not i.startswith(".")
            ]
            ex = visible_tables[0]
            if ".h5" in ex:
                self.table_format = ".h5"
            elif ".csv" in ex:
                self.table_format = ".csv"

        self.videos = sorted(
            [
                vid
                for vid in os.listdir(self.video_path)
                if vid.endswith(video_format) and not vid.startswith(".")
            ]
        )
        self.tables = sorted(
            [
                tab
                for tab in os.listdir(self.table_path)
                if tab.endswith(self.table_format) and not tab.startswith(".")
            ]
        )
        self.angles = True
        self.animal_ids = animal_ids
        self.arena = arena
        self.arena_dims = arena_dims
        self.distances = "all"
        self.ego = False
        self.exp_conditions = exp_conditions
        self.interpolate_outliers = interpolate_outliers
        self.interpolation_limit = interpolation_limit
        self.interpolation_std = interpolation_std
        self.likelihood_tolerance = likelihood_tol
        # get_scale is a property that calls deepof.utils.recognize_arena once per
        # video, so this line does the arena detection right away. It must stay below
        # the assignments of videos, video_path, arena and arena_dims, which it reads
        self.scales = self.get_scale
        self.smooth_alpha = smooth_alpha
        self.subset_condition = None
        self.video_format = video_format

        model_dict = {
            "mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
        }
        self.connectivity = model_dict[model]
        self.exclude_bodyparts = exclude_bodyparts
        if self.exclude_bodyparts != tuple([""]):
            for bp in exclude_bodyparts:
                self.connectivity.remove_node(bp)
lucas_miranda's avatar
lucas_miranda committed
121

122
123
    def __str__(self):
        """Returns a one-line summary: number of videos and, when experimental
        conditions were provided, the number of distinct conditions"""
        if not self.exp_conditions:
            return "deepof analysis of {} videos".format(len(self.videos))

        return "deepof analysis of {} videos across {} conditions".format(
            len(self.videos), len(set(self.exp_conditions.values()))
        )
129

130
131
132
133
134
135
136
137
    @property
    def subset_condition(self):
        """String or None. If set, load_tables keeps only the top-level column labels
        of each table that start with this prefix and drops the rest"""
        return self._subset_condition

    @property
    def distances(self):
        """List. If not 'all', sets the body parts among which the
        distances will be computed"""
        return self._distances

    @property
    def ego(self):
        """String, name of a body part (False by default). If set, get_distances keeps
        only the distance columns that involve the specified body part"""
        return self._ego

    @property
    def angles(self):
        """Bool. Toggles angle computation. True by default. If turned off,
        enhances performance for big datasets"""
        return self._angles

    @property
    def get_scale(self) -> np.array:
        """
        Returns the arena as recognised from the videos.

            Returns:
                scales (np.array): one row per video. Each row holds ellipse[0][0],
                ellipse[0][1] and ellipse[1][1] (all multiplied by two), where ellipse is
                the first element returned by deepof.utils.recognize_arena, followed by
                the values in arena_dims.

            Raises:
                NotImplementedError: if the arena type is not 'circular'
        """

        if self.arena not in ["circular"]:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        scales = []
        for vid_index in range(len(self.videos)):

            ellipse = deepof.utils.recognize_arena(
                self.videos,
                vid_index,
                path=self.video_path,
                arena_type=self.arena,
            )[0]

            # NOTE(review): presumably ellipse[0] is the centre and ellipse[1] the axes
            # of the detected arena -- confirm against deepof.utils.recognize_arena
            scales.append(
                list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1]]) * 2)
                + list(self.arena_dims)
            )

        return np.array(scales)

180
    def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
        """
        Loads the tracking tables into dictionaries and preprocesses them.

        Steps, in order: read every file in self.tables, split coordinates from
        likelihoods, optionally smooth, drop the scorer column level, optionally keep
        only the labels matching subset_condition, optionally drop excluded body parts
        and optionally interpolate outliers.

            Parameters:
                - verbose (bool): if True, prints a message before the main steps

            Returns:
                - tab_dict (dict): x/y coordinate table per video key. The key is the part
                of the file name that precedes 'DLC'
                - lik_dict (dict): likelihood table per video key

            Raises:
                - NotImplementedError: if table_format is neither '.h5' nor '.csv'
        """

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        # Pick the reader once; the guard above ensures one of the two formats
        if self.table_format == ".h5":

            def read_table(table_file):
                return pd.read_hdf(table_file, dtype=float)

        else:

            def read_table(table_file):
                return pd.read_csv(
                    table_file,
                    header=[0, 1, 2],
                    index_col=0,
                    dtype=float,
                )

        tab_dict = {
            deepof.utils.re.findall("(.*)DLC", tab)[0]: read_table(
                os.path.join(self.table_path, tab)
            )
            for tab in self.tables
        }

        lik_dict = defaultdict()

        # Separate the x/y coordinates from the likelihood reported for each body part
        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik = value.xs("likelihood", level="coords", axis=1, drop_level=True)

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik.droplevel("scorer", axis=1)

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    deepof.utils.smooth_mult_trajectory(
                        np.array(tab), alpha=self.smooth_alpha
                    )
                )
                smooth.columns = cols
                # NOTE(review): the first row is dropped here but lik_dict keeps all of
                # its rows -- confirm interpolate_outliers tolerates the one-row offset
                tab_dict[key] = smooth.iloc[1:, :].reset_index(drop=True)

        # Keep only the columns under the first top-level label (the scorer)
        for key, tab in tab_dict.items():
            tab_dict[key] = tab[tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in tab_dict.items():
                # Labels that do not start with the requested prefix get removed
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                tab_dict[key] = tab

        # DataFrame.drop treats a tuple as ONE label, so the body parts are always
        # handed over as a list; [""] / ("",) both mean "nothing to exclude"
        excluded = list(self.exclude_bodyparts)
        if excluded != [""]:

            for k, value in tab_dict.items():
                temp = value.drop(excluded, axis=1, level="bodyparts")
                temp.sort_index(axis=1, inplace=True)
                temp.columns = pd.MultiIndex.from_product(
                    [sorted(list(set([i[j] for i in temp.columns]))) for j in range(2)]
                )
                tab_dict[k] = temp.sort_index(axis=1)

        if self.interpolate_outliers:

            if verbose:
                print("Interpolating outliers...")

            for k, value in tab_dict.items():
                tab_dict[k] = deepof.utils.interpolate_outliers(
                    value,
                    lik_dict[k],
                    likelihood_tolerance=self.likelihood_tolerance,
                    mode="or",
                    limit=self.interpolation_limit,
                    n_std=self.interpolation_std,
                )

        return tab_dict, lik_dict
287

288
289
    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """
        Computes the distances between all selected body parts over time.
        If ego is provided, it only returns distances to a specified bodypart.

            Parameters:
                - tab_dict (dict): coordinate table per video key, as returned by load_tables
                - verbose (bool): if True, prints a message before computing

            Returns:
                - distance_dict (dict): distance table per video key

            Raises:
                - ValueError: if self.distances names a body part that is not among the
                top-level column labels of the first table
        """

        if verbose:
            print("Computing distances...")

        bodyparts = list(tab_dict.values())[0].columns.levels[0]

        nodes = self.distances
        # isinstance guard: comparing a list/array-like against "all" is not a plain bool
        if isinstance(nodes, str) and nodes == "all":
            nodes = bodyparts

        # The previous `assert [...]` tested the truthiness of a list and therefore
        # never rejected anything; check every requested node explicitly
        if not all(node in bodyparts for node in nodes):
            raise ValueError("Nodes should correspond to existent bodyparts")

        # Columns 2 onwards of each scales row are passed to bpart_distance
        scales = self.scales[:, 2:]

        distance_dict = {
            key: deepof.utils.bpart_distance(
                tab,
                scales[i, 1],
                scales[i, 0],
            )
            for i, (key, tab) in enumerate(tab_dict.items())
        }

        # Keep the distance columns whose members are all among the selected nodes
        for key, dists in distance_dict.items():
            distance_dict[key] = dists.loc[
                :, [all(i in nodes for i in j) for j in dists.columns]
            ]

        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
                    :, [dist for dist in val.columns if self.ego in dist]
                ]

        return distance_dict

327
    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """

        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (graph): graph stating to which bodyparts each bodypart is connected;
            its cliques of size three are the trios for which angles are computed;
        Parameters:
            tab_dict (dict of dataframes): tables loaded from the data;
            verbose (bool): if True, prints a message before computing;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        cliques = deepof.utils.nx.enumerate_all_cliques(self.connectivity)
        cliques = [i for i in cliques if len(i) == 3]

        # Each angle column is labelled with the trio reordered in these three ways
        orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]

        angle_dict = {}
        for key, tab in tab_dict.items():

            dats = []
            for clique in cliques:
                # NOTE(review): a C-order reshape of an (n_frames, 6) array to
                # (3, n_frames, 2) does not split the data by body part -- confirm this
                # is the layout deepof.utils.angle_trio expects
                dat = pd.DataFrame(
                    deepof.utils.angle_trio(
                        np.array(tab[clique]).reshape([3, tab.shape[0], 2])
                    )
                ).T

                dat.columns = [tuple(clique[i] for i in order) for order in orders]

                dats.append(dat)

            angle_dict[key] = pd.concat(dats, axis=1)

        return angle_dict
367

368
    def run(self, verbose: bool = True) -> Coordinates:
        """
        Generates a dataset using all the options specified during initialization.

            Parameters:
                - verbose (bool): forwarded to load_tables, get_distances and get_angles;
                if True, 'Done!' is printed at the end

            Returns:
                - a coordinates object holding the tables, their quality (likelihoods),
                and the distances / angles (None for whichever is switched off)
        """

        tables, quality = self.load_tables(verbose)

        # Distances and angles are only computed when the matching attribute is truthy
        distances = self.get_distances(tables, verbose) if self.distances else None
        angles = self.get_angles(tables, verbose) if self.angles else None

        if verbose:
            print("Done!")

        return coordinates(
            angles=angles,
            animal_ids=self.animal_ids,
            arena=self.arena,
            arena_dims=self.arena_dims,
            distances=distances,
            exp_conditions=self.exp_conditions,
            path=self.path,
            quality=quality,
            scales=self.scales,
            tables=tables,
            videos=self.videos,
        )

398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
    # Stores the prefix read by load_tables; no validation is performed
    @subset_condition.setter
    def subset_condition(self, value):
        self._subset_condition = value

    # Stores 'all' or the body parts read by get_distances; no validation is performed
    @distances.setter
    def distances(self, value):
        self._distances = value

    # Stores the body part name (or a falsy value) read by get_distances
    @ego.setter
    def ego(self, value):
        self._ego = value

    # Stores the flag that run checks before calling get_angles
    @angles.setter
    def angles(self, value):
        self._angles = value

414
415

class coordinates:
416
417
418
419
420
421
422
    """

    Class for storing the results of a run project. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

423
    def __init__(
424
        self,
425
426
        arena: str,
        arena_dims: np.array,
427
        path: str,
428
429
430
431
        quality: dict,
        scales: np.array,
        tables: dict,
        videos: list,
432
        angles: dict = None,
433
        animal_ids: List = tuple([""]),
434
435
        distances: dict = None,
        exp_conditions: dict = None,
436
    ):
437
        self._animal_ids = animal_ids
438
439
        self._arena = arena
        self._arena_dims = arena_dims
440
        self._exp_conditions = exp_conditions
441
        self._path = path
442
443
444
445
446
447
        self._quality = quality
        self._scales = scales
        self._tables = tables
        self._videos = videos
        self.angles = angles
        self.distances = distances
448
449
450
451

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
452
                len(self._videos), len(set(self._exp_conditions.values()))
453
454
            )
        else:
455
            return "deepof analysis of {} videos".format(len(self._videos))
456

457
    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
        align_inplace: bool = False,
        propagate_labels: bool = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If False, the raw data is returned; if 'arena' (default), coordinates are
                centered in the pitch. When centering on a body part, the columns of that
                body part are removed from the output
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (str): name of the body part whose columns are moved to the front of each table,
                so that later processes can align the frames with it (see preprocess in table_dict
                documentation). False (default) leaves the column order untouched
                - align_inplace (bool): Only valid if align is set and polar is False. Aligns the vector
                that goes from the origin to the selected body part with the y axis, for all time points.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        # Work on a copy: the centering below writes into the tables in place
        tabs = deepof.utils.deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = deepof.utils.tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":

                for i, (key, value) in enumerate(tabs.items()):

                    # Cartesian tables have x/y columns; the KeyError fallback handles
                    # tables that were converted to rho/phi above
                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        # NOTE(review): this subtracts the same offsets from rho and phi
                        # as from x and y -- confirm that is intended for polar data
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif isinstance(center, str):

            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)
                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                # The centering body part is dropped from the output
                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = deepof.utils.rolling_speed(tab, deriv=speed, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype("timedelta64[s]")

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns
                tab = tab[columns]
                tabs[key] = tab

                if align_inplace and polar is False:
                    index = tab.columns
                    tab = pd.DataFrame(
                        deepof.utils.align_trajectories(np.array(tab), mode="all")
                    )
                    tab.columns = index
                    tabs[key] = tab

        if propagate_labels:
            for key, tab in tabs.items():
                tab["pheno"] = self._exp_conditions[key]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
            propagate_labels=propagate_labels,
        )

594
595
596
    def get_distances(
        self, speed: int = 0, length: str = None, propagate_labels: bool = False
    ) -> Table_dict:
597
598
599
600
601
602
603
604
        """
        Returns a table_dict object with the distances between body parts animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
605
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
606
607
608
609

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """
lucas_miranda's avatar
lucas_miranda committed
610

611
        tabs = deepof.utils.deepcopy(self.distances)
lucas_miranda's avatar
lucas_miranda committed
612

lucas_miranda's avatar
lucas_miranda committed
613
        if self.distances is not None:
lucas_miranda's avatar
lucas_miranda committed
614
615

            if speed:
lucas_miranda's avatar
lucas_miranda committed
616
                for key, tab in tabs.items():
617
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="dists")
lucas_miranda's avatar
lucas_miranda committed
618
                    tabs[key] = vel
lucas_miranda's avatar
lucas_miranda committed
619

620
621
622
            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
623
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
lucas_miranda's avatar
lucas_miranda committed
624
                    ).astype("timedelta64[s]")
625

626
627
628
629
            if propagate_labels:
                for key, tab in tabs.items():
                    tab["pheno"] = self._exp_conditions[key]

630
            return table_dict(tabs, propagate_labels=propagate_labels, typ="dists")
lucas_miranda's avatar
lucas_miranda committed
631

632
633
        raise ValueError(
            "Distances not computed. Read the documentation for more details"
634
        )  # pragma: no cover
635

636
    def get_angles(
637
638
639
640
641
        self,
        degrees: bool = False,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
642
643
644
645
646
647
648
649
650
651
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts animal as values.

            Parameters:
                - angles (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
652
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
653
654
655
656

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """
lucas_miranda's avatar
lucas_miranda committed
657

658
        tabs = deepof.utils.deepcopy(self.angles)
lucas_miranda's avatar
lucas_miranda committed
659

lucas_miranda's avatar
lucas_miranda committed
660
        if self.angles is not None:
lucas_miranda's avatar
lucas_miranda committed
661
662
663
664
            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
lucas_miranda's avatar
lucas_miranda committed
665
                for key, tab in tabs.items():
666
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="angles")
lucas_miranda's avatar
lucas_miranda committed
667
                    tabs[key] = vel
lucas_miranda's avatar
lucas_miranda committed
668

669
670
671
            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
672
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
lucas_miranda's avatar
lucas_miranda committed
673
                    ).astype("timedelta64[s]")
674

675
676
677
678
            if propagate_labels:
                for key, tab in tabs.items():
                    tab["pheno"] = self._exp_conditions[key]

679
            return table_dict(tabs, propagate_labels=propagate_labels, typ="angles")
lucas_miranda's avatar
lucas_miranda committed
680

681
682
683
        raise ValueError(
            "Angles not computed. Read the documentation for more details"
        )  # pragma: no cover
lucas_miranda's avatar
lucas_miranda committed
684

685
686
687
    def get_videos(self, play: bool = False):
        """Retuens the videos associated with the dataset as a list."""

688
        if play:  # pragma: no cover
689
690
691
692
693
694
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
695
696
        """Returns the stored dictionary with experimental conditions per subject"""

697
698
        return self._exp_conditions

699
    def get_quality(self):
700
701
        """Retrieves a dictionary with the tagging quality per video, as reported by DLC"""

702
703
704
705
        return self._quality

    @property
    def get_arenas(self):
706
707
        """Retrieves all available information associated with the arena"""

708
709
        return self._arena, self._arena_dims, self._scales

710
    # noinspection PyDefaultArgument
711
    def rule_based_annotation(
        self,
        params: Dict = None,
        video_output: bool = False,
        frame_limit: int = np.inf,
        debug: bool = False,
    ) -> Table_dict:
        """
        Annotates coordinates using a simple rule-based pipeline.

            Parameters:
                - params (Dict): forwarded to deepof.pose_utils.rule_based_tagging and
                rule_based_video. An empty dict is used when omitted
                - video_output: False (default) renders nothing; 'all' renders every video;
                a list renders the videos whose keys it contains
                - frame_limit (int): forwarded to deepof.pose_utils.rule_based_video
                - debug (bool): forwarded to deepof.pose_utils.rule_based_video

            Returns:
                tab_dict (Table_dict): table_dict object of type 'rule-based' with one
                tag table per video key

            Raises:
                AttributeError: if video_output is truthy but neither 'all' nor a list
        """

        # A fresh dict per call: a `{}` default would be shared across invocations
        if params is None:
            params = {}

        tag_dict = {}
        # noinspection PyTypeChecker
        coords = self.get_coords(center=False)
        dists = self.get_distances()
        speeds = self.get_coords(speed=1)

        for key in tqdm(self._tables.keys()):

            video = [vid for vid in self._videos if key + "DLC" in vid][0]
            tag_dict[key] = deepof.pose_utils.rule_based_tagging(
                list(self._tables.keys()),
                self._videos,
                self,
                coords,
                dists,
                speeds,
                self._videos.index(video),
                arena_type=self._arena,
                recog_limit=1,
                path=os.path.join(self._path, "Videos"),
                params=params,
            )

        if video_output:  # pragma: no cover

            if isinstance(video_output, list):
                vid_idxs = video_output
            elif video_output == "all":
                vid_idxs = list(self._tables.keys())
            else:
                raise AttributeError(
                    "Video output must be either 'all' or a list with the names of the videos to render"
                )

            # cpu_count() // 2 is 0 on a single-core machine, which joblib rejects
            njobs = max(1, cpu_count() // 2)

            # The context manager closes the progress bar even if rendering fails
            with tqdm(total=len(vid_idxs)) as pbar:

                def output_video(idx):
                    """Outputs a single annotated video. Enclosed in a function to enable parallelization"""

                    deepof.pose_utils.rule_based_video(
                        self,
                        list(self._tables.keys()),
                        self._videos,
                        list(self._tables.keys()).index(idx),
                        tag_dict[idx],
                        debug=debug,
                        frame_limit=frame_limit,
                        recog_limit=1,
                        path=os.path.join(self._path, "Videos"),
                        params=params,
                    )
                    pbar.update(1)

                with parallel_backend("threading", n_jobs=njobs):
                    Parallel()(delayed(output_video)(key) for key in vid_idxs)

        return table_dict(
            tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
        )
780

781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
    @staticmethod
    def deep_unsupervised_embedding(
        preprocessed_object: np.array,
        encoding_size: int,
        batch_size: int = 256,
        cp: bool = True,
        hparams: dict = None,
        kl_warmup: int = 0,
        loss: str = "ELBO",
        mmd_warmup: int = 0,
        montecarlo_kl: int = 10,
        n_components: int = 25,
        outpath: str = ".",
        phenotype_class: float = 0,
        predictor: float = 0,
        pretrained: str = False,
        variational: bool = True,
    ):
        """

        Builds one of the sequence autoencoders defined in deepof.models from a preprocessed dataset.

            Parameters:
                - preprocessed_object: iterable unpacked as (X_train, y_train, X_val, y_val).
                Only X_train.shape is used here, to build the model.
                - encoding_size (int): passed as 'encoding' to SEQ_2_SEQ_GMVAE.
                - batch_size, loss, montecarlo_kl, predictor: passed under the same names to SEQ_2_SEQ_GMVAE.
                - kl_warmup, mmd_warmup (int): passed as 'kl_warmup_epochs' and 'mmd_warmup_epochs'.
                - n_components (int): passed as 'number_of_components'.
                - phenotype_class (float): passed as 'phenotype_prediction'.
                - hparams (dict): architecture hyperparameters; an empty dictionary is used if None.
                - pretrained: if truthy, passed to ae.load_weights before returning.
                - variational (bool): if True, SEQ_2_SEQ_GMVAE is built; otherwise SEQ_2_SEQ_AE.
                - cp, outpath: accepted but not used by this method.

            Returns:
                - if pretrained is truthy: (encoder, generator, grouper, ae) for the variational model,
                or (encoder, decoder, ae) otherwise, with the weights loaded into ae.
                - if pretrained is falsy: only the full autoencoder (ae).

        """

        # Labels and validation data are unpacked only to enforce the 4-element layout
        X_train, _y_train, _X_val, _y_val = preprocessed_object
        architecture = hparams if hparams is not None else {}

        if variational:
            gmvae = deepof.models.SEQ_2_SEQ_GMVAE(
                architecture_hparams=architecture,
                batch_size=batch_size,
                compile_model=True,
                encoding=encoding_size,
                kl_warmup_epochs=kl_warmup,
                loss=loss,
                mmd_warmup_epochs=mmd_warmup,
                montecarlo_kl=montecarlo_kl,
                neuron_control=False,
                number_of_components=n_components,
                overlap_loss=False,
                phenotype_prediction=phenotype_class,
                predictor=predictor,
            )
            # The two warm-up callbacks returned by build() are not used here
            encoder, generator, grouper, ae, _kl_callback, _mmd_callback = gmvae.build(
                X_train.shape
            )
            return_list = (encoder, generator, grouper, ae)

        else:
            plain_ae = deepof.models.SEQ_2_SEQ_AE(architecture)
            encoder, decoder, ae = plain_ae.build(X_train.shape)
            return_list = (encoder, decoder, ae)

        # NOTE(review): no training happens in this method; without pretrained
        # weights the model is returned exactly as built
        if not pretrained:
            return ae

        ae.load_weights(pretrained)
        return return_list
842

843
844

class table_dict(dict):
    """

    Main class for storing a single dataset as a dictionary with individuals as keys and pandas.DataFrames as values.
    Includes methods for generating training and testing datasets for the autoencoders.

    """

    def __init__(
        self,
        tabs: Dict,
        typ: str,
        arena: str = None,
        arena_dims: np.array = None,
        center: str = None,
        polar: bool = None,
        propagate_labels: bool = False,
    ):
        """

        Stores the given tables together with the metadata describing how they were generated.

            Parameters:
                - tabs (dict): tables to store; used to initialise the underlying dictionary.
                - typ (str): kind of data stored (e.g. 'coords' or 'rule-based').
                - arena (str): arena type (e.g. 'circular').
                - arena_dims (np.array): dimensions of the arena.
                - center (str): centering applied to the data; falsy if the data is not centered.
                - polar (bool): truthy if the stored coordinates are polar.
                - propagate_labels (bool): if True, the last column of each table is treated as
                a label when generating training sets.

        """

        super().__init__(tabs)
        self._type = typ
        self._center = center
        self._polar = polar
        self._arena = arena
        self._arena_dims = arena_dims
        self._propagate_labels = propagate_labels
869

870
    def filter_videos(self, keys: list) -> Table_dict:
871
        """Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
872
        for selecting data coming from videos of a specified condition."""
873
874
875

        assert np.all([k in self.keys() for k in keys]), "Invalid keys selected"

876
877
878
        return table_dict(
            {k: value for k, value in self.items() if k in keys}, self._type
        )
879

lucas_miranda's avatar
lucas_miranda committed
880
    # noinspection PyTypeChecker
    def plot_heatmaps(
        self,
        bodyparts: list,
        xlim: float = None,
        ylim: float = None,
        save: bool = False,
        i: int = 0,
        dpi: int = 100,
    ) -> plt.figure:
        """

        Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)

            Parameters:
                - bodyparts (list): body parts to plot; forwarded to deepof.visuals.plot_heatmap.
                - xlim, ylim (float): axis limits, forwarded to deepof.visuals.plot_heatmap.
                - save (bool): forwarded to deepof.visuals.plot_heatmap.
                - i (int): position, among the stored tables, of the one to plot.
                - dpi (int): forwarded to deepof.visuals.plot_heatmap.

            Returns:
                - the result of deepof.visuals.plot_heatmap if the arena is 'circular';
                None for any other arena type.

            Raises:
                - NotImplementedError: if the stored data is not of type 'coords' or is polar.

        """

        if self._type != "coords" or self._polar:
            raise NotImplementedError(
                "Heatmaps only available for cartesian coordinates. "
                "Set polar to False in get_coordinates and try again"
            )  # pragma: no cover

        if not self._center:  # pragma: no cover
            warnings.warn("Heatmaps look better if you center the data")

        if self._arena == "circular":

            heatmaps = deepof.visuals.plot_heatmap(
                list(self.values())[i],
                bodyparts,
                xlim=xlim,
                ylim=ylim,
                save=save,
                dpi=dpi,
            )

            return heatmaps

        # Only circular arenas are handled; nothing is plotted for any other type
        return None

914
    def get_training_set(
        self,
        test_videos: int = 0,
        encode_labels: bool = True,
    ) -> Tuple[np.ndarray, list, Union[np.ndarray, list], list]:
        """

        Generates training and test sets as numpy.array objects for model training

            Parameters:
                - test_videos (int): number of videos to hold out for the test set. If labels are
                propagated, this many videos are held out for each distinct label.
                - encode_labels (bool): if True, labels are encoded with a LabelEncoder fitted
                on the training labels.

            Returns:
                - X_train (np.ndarray): rows of all training videos, stacked along the first axis
                - y_train: training labels (empty if labels are not propagated)
                - X_test: rows of all test videos, stacked along the first axis
                (empty if test_videos is 0)
                - y_test: test labels (empty if labels are not propagated or test_videos is 0)

        """

        # One array per video, kept in a plain list: wrapping equally-shaped videos in an
        # object ndarray would stack them into a single 3D object array and leak the
        # object dtype into the returned sets
        videos = [np.array(v) for v in self.values()]

        if self._propagate_labels:
            # Held-out videos are drawn separately for each label (stored in the last
            # column); a video belongs to the label found in its first row
            labels = {label for video in videos for label in video[:, -1]}
            test_index = np.array([], dtype=int)
            for label in labels:
                candidates = [
                    idx for idx, video in enumerate(videos) if video[0, -1] == label
                ]
                label_index = np.random.choice(candidates, test_videos, replace=False)
                test_index = np.concatenate([test_index, label_index])
        else:
            test_index = np.random.choice(
                range(len(videos)), test_videos, replace=False
            )

        y_train, X_test, y_test = [], [], []
        if test_videos > 0:
            held_out = [int(idx) for idx in test_index]
            excluded = set(held_out)
            X_test = np.concatenate([videos[idx] for idx in held_out])
            X_train = np.concatenate(
                [video for idx, video in enumerate(videos) if idx not in excluded]
            )

        else:
            X_train = np.concatenate(videos)

        if self._propagate_labels:
            X_train, y_train = X_train[:, :-1], X_train[:, -1]
            # Without test videos X_test is still an empty list, so there is nothing to split
            if test_videos > 0:
                X_test, y_test = X_test[:, :-1], X_test[:, -1]

        if encode_labels:
            le = LabelEncoder()
            y_train = le.fit_transform(y_train)
            y_test = le.transform(y_test)

        return X_train, y_train, X_test, y_test
959

lucas_miranda's avatar
lucas_miranda committed
960
    # noinspection PyTypeChecker,PyGlobalUndefined
961
    def preprocess(
962
        self,
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
        window_size: int = 1,
        window_step: int = 1,
        scale: str = "standard",
        test_videos: int = 0,
        verbose: bool = False,
        conv_filter: bool = None,
        sigma: float = 1.0,
        shift: float = 0.0,
        shuffle: bool = False,
        align: str = False,
    ) -> np.ndarray:
        """

        Main method for preprocessing the loaded dataset. Capable of returning training
        and test sets ready for model training.

            Parameters:
                - window_size (int): Size of the sliding window to pass through the data to generate training instances
                - window_step (int): Step to take when sliding the window. If 1, a true sliding window is used;
                if equal to window_size, the data is split into non-overlapping chunks.
                - scale (str): Data scaling method. Must be one of 'standard' (default; recommended) and 'minmax'.
                - test_videos (int): Number of videos to use when generating the test set.
                If 0, no test set is generated (not recommended).
                - verbose (bool): prints job information if True
                - conv_filter (bool): must be one of None, 'gaussian'. If not None, convolves each instance
                with the specified kernel.
                - sigma (float): usable only if conv_filter is 'gaussian'. Standard deviation of the kernel to use.
                - shift (float): usable only if conv_filter is 'gaussian'. Shift from mean zero of the kernel to use.
                - shuffle (bool): Shuffles the data instances if True. In most use cases, it should be True for training
                and False for prediction.
                - align (bool): If "all", rotates all data instances to align the center -> align (selected before
                when calling get_coords) axis with the y-axis of the cartesian plane. If 'center', rotates all instances
                using the angle of the central frame of the sliding window. This way rotations of the animal are caught
                as well. It doesn't do anything if False.
997
                - propagate_labels (bool): If True, returns a label vector acompaigning each training instance
998
999
1000
1001
1002
1003
1004
1005

            Returns:
                - X_train (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all training videos
                - X_test (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all test videos (if test_videos > 0)

        """
1006

lucas_miranda's avatar
lucas_miranda committed
1007
        global g
1008
        X_train, y_train, X_test, y_test = self.get_training_set(test_videos)
1009
1010
1011
1012
1013

        if scale:
            if verbose:
                print("Scaling data...")