# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations to get the data into the desired shape.
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.

"""

from collections import defaultdict
from joblib import delayed, Parallel, parallel_backend
from typing import Dict, List
from pandas_profiling import ProfileReport
from psutil import cpu_count
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from tqdm import tqdm
import deepof.pose_utils
import deepof.utils
import deepof.visuals
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import warnings

# DEFINE CUSTOM ANNOTATED TYPES #

Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)

# CLASSES FOR PREPROCESSING AND DATA WRANGLING


class project:
    """

    Class for loading and preprocessing DLC data of individual and multiple animals. All main computations are called
    here.

    """

    def __init__(
        self,
        video_format: str = ".mp4",
        table_format: str = ".h5",
        path: str = deepof.utils.os.path.join("."),
        exp_conditions: dict = None,
        arena: str = "circular",
        smooth_alpha: float = 0.1,
        arena_dims: tuple = (1,),
        model: str = "mouse_topview",
        animal_ids: List = tuple([""]),
    ):

        self.path = path
        self.video_path = self.path + "/Videos/"
        self.table_path = self.path + "/Tables/"
        self.videos = sorted(
            [
                vid
                for vid in deepof.utils.os.listdir(self.video_path)
                if vid.endswith(video_format)
            ]
        )
        self.tables = sorted(
            [
                tab
                for tab in deepof.utils.os.listdir(self.table_path)
                if tab.endswith(table_format)
            ]
        )
        self.exp_conditions = exp_conditions
        self.table_format = table_format
        self.video_format = video_format
        self.arena = arena
        self.arena_dims = arena_dims
        self.smooth_alpha = smooth_alpha
        self.scales = self.get_scale
        self.animal_ids = animal_ids

        self.subset_condition = None
        self.distances = "all"
        self.ego = False
        self.angles = True

        model_dict = {"mouse_topview": deepof.utils.connect_mouse_topview()}
        self.connectivity = model_dict[model]

    def __str__(self):
        if self.exp_conditions:
            return "deepof analysis of {} videos across {} conditions".format(
                len(self.videos), len(self.exp_conditions)
            )
        else:
            return "deepof analysis of {} videos".format(len(self.videos))

    @property
    def subset_condition(self):
        """Sets a subset condition for the videos to load. If set,
        only the videos with the included pattern will be loaded"""
        return self._subset_condition

    @property
    def distances(self):
        """List. If not 'all', sets the body parts among which the
        distances will be computed"""
        return self._distances

    @property
    def ego(self):
        """String, name of a body part. If True, computes only the distances
        between the specified body part and the rest"""
        return self._ego

    @property
    def angles(self):
        """Bool. Toggles angle computation. True by default. If turned off,
         enhances performance for big datasets"""
        return self._angles

    @property
    def get_scale(self) -> np.array:
        """Returns the arena as recognised from the videos"""

        if self.arena in ["circular"]:

            scales = []
            for vid_index, _ in enumerate(self.videos):
                scales.append(
                    list(
                        deepof.utils.recognize_arena(
                            self.videos,
                            vid_index,
                            path=self.video_path,
                            arena_type=self.arena,
                        )[0]
                        * 2
                    )
                    + list(self.arena_dims)
                )

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales)

    def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
        """Loads videos and tables into dictionaries"""

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        tab_dict = {}

        if self.table_format == ".h5":

            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_hdf(
                    deepof.utils.os.path.join(self.table_path, tab), dtype=float
                )
                for tab in self.tables
            }
        elif self.table_format == ".csv":

            for tab in self.tables:
                head = pd.read_csv(
                    deepof.utils.os.path.join(self.table_path, tab), nrows=2
                )
                data = pd.read_csv(
                    deepof.utils.os.path.join(self.table_path, tab),
                    skiprows=2,
                    index_col="coords",
                    dtype={"coords": int},
                ).drop("1", axis=1)
                data.columns = pd.MultiIndex.from_product(
                    [
                        [head.columns[2]],
                        set(list(head.iloc[0])[2:]),
                        ["x", "y", "likelihood"],
                    ],
                    names=["scorer", "bodyparts", "coords"],
                )
                tab_dict[deepof.utils.re.findall("(.*)DLC", tab)[0]] = data

        lik_dict = defaultdict()

        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik: pd.DataFrame = value.xs(
                "likelihood", level="coords", axis=1, drop_level=True
            )

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    deepof.utils.smooth_mult_trajectory(
                        np.array(tab), alpha=self.smooth_alpha
                    )
                )
                smooth.columns = cols
                tab_dict[key] = smooth

        for key, tab in tab_dict.items():
            tab_dict[key] = tab[tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in tab_dict.items():
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                tab_dict[key] = tab

        return tab_dict, lik_dict

    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """Computes the distances between all selected body parts over time.
           If ego is provided, it only returns distances to a specified bodypart"""

        if verbose:
            print("Computing distances...")

        nodes = self.distances
        if nodes == "all":
            nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]

        assert [
            i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
        ], "Nodes should correspond to existent bodyparts"

        scales = self.scales[:, 2:]

        distance_dict = {
            key: deepof.utils.bpart_distance(tab, scales[i, 1], scales[i, 0],)
            for i, (key, tab) in enumerate(tab_dict.items())
        }

        for key in distance_dict.keys():
            distance_dict[key] = distance_dict[key].loc[
                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
            ]

        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
                    :, [dist for dist in val.columns if self.ego in dist]
                ]

        return distance_dict

    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """

        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        cliques = deepof.utils.nx.enumerate_all_cliques(self.connectivity)
        cliques = [i for i in cliques if len(i) == 3]

        angle_dict = {}
        for key, tab in tab_dict.items():

            dats = []
            for clique in cliques:
                dat = pd.DataFrame(
                    deepof.utils.angle_trio(
                        np.array(tab[clique]).reshape([3, tab.shape[0], 2])
                    )
                ).T

                orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]
                dat.columns = [tuple(clique[i] for i in order) for order in orders]

                dats.append(dat)

            dats = pd.concat(dats, axis=1)

            angle_dict[key] = dats

        return angle_dict

    def run(self, verbose: bool = True) -> Coordinates:
        """Generates a dataset using all the options specified during initialization"""

        tables, quality = self.load_tables(verbose)
        distances = None
        angles = None

        if self.distances:
            distances = self.get_distances(tables, verbose)

        if self.angles:
            angles = self.get_angles(tables, verbose)

        if verbose:
            print("Done!")

        return coordinates(
            tables=tables,
            videos=self.videos,
            arena=self.arena,
            arena_dims=self.arena_dims,
            scales=self.scales,
            quality=quality,
            exp_conditions=self.exp_conditions,
            distances=distances,
            angles=angles,
            animal_ids=self.animal_ids,
            path=self.path,
        )

    @subset_condition.setter
    def subset_condition(self, value):
        self._subset_condition = value

    @distances.setter
    def distances(self, value):
        self._distances = value

    @ego.setter
    def ego(self, value):
        self._ego = value

    @angles.setter
    def angles(self, value):
        self._angles = value


class coordinates:
    """

    Class for storing the results of a project run. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

    def __init__(
        self,
        tables: dict,
        videos: list,
        arena: str,
        arena_dims: np.array,
        scales: np.array,
        quality: dict,
        path: str,
        exp_conditions: dict = None,
        distances: dict = None,
        angles: dict = None,
        animal_ids: List = tuple([""]),
    ):
        self._tables = tables
        self.distances = distances
        self.angles = angles
        self._videos = videos
        self._exp_conditions = exp_conditions
        self._arena = arena
        self._arena_dims = arena_dims
        self._scales = scales
        self._quality = quality
        self._animal_ids = animal_ids
        self._path = path

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
                len(self._videos), len(self._exp_conditions)
            )
        else:
            return "deepof analysis of {} videos".format(len(self._videos))

    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If False, the raw data is returned; if 'arena' (default), coordinates are
                centered in the arena
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (str): selects the body part to which later processes will align the frames
                (see preprocess in table_dict documentation).

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = deepof.utils.tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":

                for i, (key, value) in enumerate(tabs.items()):

                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif type(center) == str and center != "arena":

            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)
                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype("timedelta64[s]")

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns
                tabs[key] = tab[columns]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
        )
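
    # Example use of get_coords (an illustrative sketch; 'my_coordinates' is a hypothetical
    # coordinates instance and '00:10:00' a hypothetical video length):
    #
    #   raw_tabs = my_coordinates.get_coords(center=False)
    #   centered_tabs = my_coordinates.get_coords(center="arena", polar=False)
    #   speed_tabs = my_coordinates.get_coords(speed=1, length="00:10:00")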

    def get_distances(self, speed: int = 0, length: str = None) -> Table_dict:
        """
        Returns a table_dict object with the distances between body parts of each animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self.distances)

        if self.distances is not None:

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="dists")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            return table_dict(tabs, typ="dists")

        raise ValueError(
            "Distances not computed. Read the documentation for more details"
        )  # pragma: no cover
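
    # Example use of get_distances (sketch; 'my_coordinates' is a hypothetical coordinates
    # instance created with distance computation enabled):
    #
    #   dist_tabs = my_coordinates.get_distances()
    #   dist_speed_tabs = my_coordinates.get_distances(speed=1, length="00:10:00")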

    def get_angles(
        self, degrees: bool = False, speed: int = 0, length: str = None
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts of each animal as values.

            Parameters:
                - degrees (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self.angles)

        if self.angles is not None:
            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="angles")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            return table_dict(tabs, typ="angles")

        raise ValueError(
            "Angles not computed. Read the documentation for more details"
        )  # pragma: no cover
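
    # Example use of get_angles (sketch; same hypothetical 'my_coordinates' instance):
    #
    #   angle_tabs = my_coordinates.get_angles(degrees=True)
    #   angle_speed_tabs = my_coordinates.get_angles(speed=1)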

    def get_videos(self, play: bool = False):
        """Returns the videos associated with the dataset as a list."""

        if play:  # pragma: no cover
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
        """Returns the stored dictionary with experimental conditions per subject"""

        return self._exp_conditions

    def get_quality(self, report: bool = False):
        """Retrieves a dictionary with the tagging quality per video, as reported by DLC"""

        if report:  # pragma: no cover
            profile = ProfileReport(
                self._quality[report],
                title="Quality Report, {}".format(report),
                html={"style": {"full_width": True}},
            )
            return profile
        return self._quality

    @property
    def get_arenas(self):
        """Retrieves all available information associated with the arena"""

        return self._arena, self._arena_dims, self._scales

    # noinspection PyDefaultArgument
    def rule_based_annotation(
        self, hparams: Dict = {}, video_output: bool = False, frame_limit: int = np.inf
    ) -> Table_dict:
        """Annotates coordinates using a simple rule-based pipeline"""

        tag_dict = {}
        # noinspection PyTypeChecker
        coords = self.get_coords(center=False)
        speeds = self.get_coords(speed=1)
        for key in tqdm(self._tables.keys()):

            video = [vid for vid in self._videos if key + "DLC" in vid][0]
            print(key, video)
            tag_dict[key] = deepof.pose_utils.rule_based_tagging(
                list(self._tables.keys()),
                self._videos,
                self,
                coords,
                speeds,
                self._videos.index(video),
                arena_type=self._arena,
                recog_limit=1,
                path=os.path.join(self._path, "Videos"),
                hparams=hparams,
            )

        if video_output:  # pragma: no cover

            def output_video(idx):
                """Outputs a single annotated video. Enclosed in a function to enable parallelization"""

                deepof.pose_utils.rule_based_video(
                    self,
                    list(self._tables.keys()),
                    self._videos,
                    list(self._tables.keys()).index(idx),
                    tag_dict[idx],
                    frame_limit=frame_limit,
                    recog_limit=1,
                    path=os.path.join(self._path, "Videos"),
                    hparams=hparams,
                )
                pbar.update(1)

            if type(video_output) == list:
                vid_idxs = video_output
            elif video_output == "all":
                vid_idxs = list(self._tables.keys())
            else:
                raise AttributeError(
                    "Video output must be either 'all' or a list with the names of the videos to render"
                )

            njobs = cpu_count(logical=False)
            pbar = tqdm(total=len(vid_idxs))
            with parallel_backend("threading", n_jobs=njobs):
                Parallel()(delayed(output_video)(key) for key in vid_idxs)
            pbar.close()

        return table_dict(
            tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
        )
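
    # Example use of rule_based_annotation (sketch; 'my_coordinates' is a hypothetical
    # coordinates instance and the frame limit is a hypothetical value):
    #
    #   tag_tabs = my_coordinates.rule_based_annotation()
    #   tag_tabs = my_coordinates.rule_based_annotation(video_output="all", frame_limit=500)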


class table_dict(dict):
    """

    Main class for storing a single dataset as a dictionary with individuals as keys and pandas.DataFrames as values.
    Includes methods for generating training and testing datasets for the autoencoders.

    """

    def __init__(
        self,
        tabs: Dict,
        typ: str,
        arena: str = None,
        arena_dims: np.array = None,
        center: str = None,
        polar: bool = None,
    ):
        super().__init__(tabs)
        self._type = typ
        self._center = center
        self._polar = polar
        self._arena = arena
        self._arena_dims = arena_dims

    def filter(self, keys: list) -> Table_dict:
        """Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
         for selecting data coming from videos of a specified condition."""

        assert np.all([k in self.keys() for k in keys]), "Invalid keys selected"

        return table_dict(
            {k: value for k, value in self.items() if k in keys}, self._type
        )
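
    # Example use of filter (sketch; the keys below are hypothetical experiment IDs):
    #
    #   subset = my_table_dict.filter(["Video1", "Video2"])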

    # noinspection PyTypeChecker
    def plot_heatmaps(
        self, bodyparts: list, save: bool = False, i: int = 0
    ) -> plt.figure:
        """Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""

        if self._type != "coords" or self._polar:
            raise NotImplementedError(
                "Heatmaps only available for cartesian coordinates. "
                "Set polar to False in get_coordinates and try again"
            )  # pragma: no cover

        if not self._center:  # pragma: no cover
            warnings.warn("Heatmaps look better if you center the data")

        if self._arena == "circular":
            x_lim = (
                [-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
                if self._center
                else [0, self._arena_dims[i][0]]
            )
            y_lim = (
                [-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
                if self._center
                else [0, self._arena_dims[i][1]]
            )

            heatmaps = deepof.visuals.plot_heatmap(
                list(self.values())[i], bodyparts, xlim=x_lim, ylim=y_lim, save=save,
            )

            return heatmaps

    def get_training_set(
        self, test_videos: int = 0
    ) -> deepof.utils.Tuple[np.ndarray, np.ndarray]:
        """Generates training and test sets as numpy.array objects for model training"""

        rmax = max([i.shape[0] for i in self.values()])
        raw_data = np.array(
            [np.pad(v, ((0, rmax - v.shape[0]), (0, 0))) for v in self.values()]
        )
        test_index = np.random.choice(range(len(raw_data)), test_videos, replace=False)

        X_test = []
        if test_videos > 0:
            X_test = np.concatenate(list(raw_data[test_index]))
            X_train = np.concatenate(list(np.delete(raw_data, test_index, axis=0)))

        else:
            X_train = np.concatenate(list(raw_data))

        return X_train, X_test
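
    # Example use of get_training_set (sketch; 'my_table_dict' is a hypothetical table_dict
    # instance). Each video is padded to the length of the longest one, and the result is
    # concatenated into 2D (time x features) train and test arrays:
    #
    #   X_train, X_test = my_table_dict.get_training_set(test_videos=2)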

    # noinspection PyTypeChecker,PyGlobalUndefined
    def preprocess(
        self,
        window_size: int = 1,
        window_step: int = 1,
        scale: str = "standard",
        test_videos: int = 0,
        verbose: bool = False,
        conv_filter: bool = None,
        sigma: float = 1.0,
        shift: float = 0.0,
        shuffle: bool = False,
        align: str = False,
    ) -> np.ndarray:
        """

        Main method for preprocessing the loaded dataset. Capable of returning training
        and test sets ready for model training.

            Parameters:
                - window_size (int): Size of the sliding window to pass through the data to generate training instances
                - window_step (int): Step to take when sliding the window. If 1, a true sliding window is used;
                if equal to window_size, the data is split into non-overlapping chunks.
                - scale (str): Data scaling method. Must be one of 'standard' (default; recommended) and 'minmax'.
                - test_videos (int): Number of videos to use when generating the test set.
                If 0, no test set is generated (not recommended).
                - verbose (bool): prints job information if True
                - conv_filter (str): must be one of None, 'gaussian'. If not None, convolves each instance
                with the specified kernel.
                - sigma (float): usable only if conv_filter is 'gaussian'. Standard deviation of the kernel to use.
                - shift (float): usable only if conv_filter is 'gaussian'. Shift from mean zero of the kernel to use.
                - shuffle (bool): Shuffles the data instances if True. In most use cases, it should be True for training
                and False for prediction.
                - align (bool): If "all", rotates all data instances to align the center -> align (selected before
                when calling get_coords) axis with the y-axis of the cartesian plane. If 'center', rotates all instances
                using the angle of the central frame of the sliding window. This way rotations of the animal are caught
                as well. It doesn't do anything if False.

            Returns:
                - X_train (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all training videos
                - X_test (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all test videos (if test_videos > 0)

        """

        global g
        X_train, X_test = self.get_training_set(test_videos)

        if scale:
            if verbose:
                print("Scaling data...")

            if scale == "standard":
                scaler = StandardScaler()
            elif scale == "minmax":
                scaler = MinMaxScaler()
            else:
                raise ValueError(
                    "Invalid scaler. Select one of standard, minmax or None"
                )  # pragma: no cover

            X_train = scaler.fit_transform(
                X_train.reshape(-1, X_train.shape[-1])
            ).reshape(X_train.shape)

            if scale == "standard":
                assert np.allclose(np.mean(X_train), 0)
                assert np.allclose(np.std(X_train), 1)

            if test_videos:
                X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
                    X_test.shape
                )

            if verbose:
                print("Done!")

        if align == "all":
            X_train = deepof.utils.align_trajectories(X_train, align)

        X_train = deepof.utils.rolling_window(X_train, window_size, window_step)

        if align == "center":
            X_train = deepof.utils.align_trajectories(X_train, align)

        if conv_filter == "gaussian":
            r = range(-int(window_size / 2), int(window_size / 2) + 1)
            r = [i - shift for i in r]
            g = np.array(
                [
                    1
                    / (sigma * np.sqrt(2 * np.pi))
                    * np.exp(-float(x) ** 2 / (2 * sigma ** 2))
                    for x in r
                ]
            )
            g /= np.max(g)
            X_train = X_train * g.reshape([1, window_size, 1])

        if test_videos:

            if align == "all":
                X_test = deepof.utils.align_trajectories(X_test, align)

            X_test = deepof.utils.rolling_window(X_test, window_size, window_step)

            if align == "center":
                X_test = deepof.utils.align_trajectories(X_test, align)

            if conv_filter == "gaussian":
                X_test = X_test * g.reshape([1, window_size, 1])

            if shuffle:
                X_train = X_train[
                    np.random.choice(X_train.shape[0], X_train.shape[0], replace=False)
                ]
                X_test = X_test[
                    np.random.choice(X_test.shape[0], X_test.shape[0], replace=False)
                ]

            return X_train, X_test

        if shuffle:
            X_train = X_train[
                np.random.choice(X_train.shape[0], X_train.shape[0], replace=False)
            ]

        return X_train
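
    # Example use of preprocess (sketch; all settings below are hypothetical):
    #
    #   X_train, X_test = my_table_dict.preprocess(
    #       window_size=11,
    #       window_step=1,
    #       scale="standard",
    #       test_videos=1,
    #       conv_filter="gaussian",
    #       sigma=2.0,
    #       shuffle=True,
    #   )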

    def random_projection(
        self, n_components: int = None, sample: int = 1000
    ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
        """Returns a training set generated from the 2D original data (time x features) and a random projection
        to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
        performance or visualization reasons"""

        X = self.get_training_set()[0]
        X = X[np.random.choice(X.shape[0], sample, replace=False), :]

        rproj = random_projection.GaussianRandomProjection(n_components=n_components)
        X = rproj.fit_transform(X)

        return X, rproj

    def pca(
        self, n_components: int = None, sample: int = 1000, kernel: str = "linear"
    ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
        """Returns a training set generated from the 2D original data (time x features) and a PCA projection
        to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
        performance or visualization reasons"""

        X = self.get_training_set()[0]
        X = X[np.random.choice(X.shape[0], sample, replace=False), :]

        pca = KernelPCA(n_components=n_components, kernel=kernel)
        X = pca.fit_transform(X)

        return X, pca

    def tsne(
        self, n_components: int = None, sample: int = 1000, perplexity: int = 30
    ) -> deepof.utils.Tuple[deepof.utils.Any, deepof.utils.Any]:
        """Returns a training set generated from the 2D original data (time x features) and a PCA projection
        to a n_components space. The sample parameter allows the user to randomly pick a subset of the data for
        performance or visualization reasons"""

        X = self.get_training_set()[0]
        X = X[np.random.choice(X.shape[0], sample, replace=False), :]

        tsne = TSNE(n_components=n_components, perplexity=perplexity)
        X = tsne.fit_transform(X)

        return X, tsne
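
    # Example use of the projection helpers (sketches; sample sizes and kernel are hypothetical):
    #
    #   X_rp, rproj = my_table_dict.random_projection(n_components=2, sample=500)
    #   X_pca, pca_fit = my_table_dict.pca(n_components=2, sample=500, kernel="rbf")
    #   X_tsne, tsne_fit = my_table_dict.tsne(n_components=2, sample=500, perplexity=30)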


def merge_tables(*args):
    """

    Takes a number of table_dict objects and merges them.
    Returns a table_dict object of type 'merged'.

    """
    merged_dict = {key: [] for key in args[0].keys()}
    for tabdict in args:
        for key, val in tabdict.items():
            merged_dict[key].append(val)

    merged_tables = table_dict(
        {
            key: pd.concat(val, axis=1, ignore_index=True)
            for key, val in merged_dict.items()
        },
        typ="merged",
    )

    return merged_tables
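
# Example use of merge_tables (sketch; 'my_coordinates' is a hypothetical coordinates instance):
#
#   merged = merge_tables(
#       my_coordinates.get_coords(center="arena"),
#       my_coordinates.get_distances(),
#       my_coordinates.get_angles(),
#   )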


# TODO:
#   - Generate ragged training array using a metric (acceleration, maybe?)
#   - Use something like Dynamic Time Warping to put all instances in the same length
#   - with the current implementation, preprocess can't fully work on merged table_dict instances.
#   While some operations (mainly alignment) should be carried out before merging, others require
#   the whole dataset to function properly.