# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations to get the data into the desired shape.
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.

"""

from collections import defaultdict
from joblib import delayed, Parallel, parallel_backend
from typing import Dict, List, Tuple, Union
from multiprocessing import cpu_count
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm
import deepof.models
import deepof.pose_utils
import deepof.utils
import deepof.visuals
import deepof.train_utils
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import warnings

# DEFINE CUSTOM ANNOTATED TYPES #

Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)

# CLASSES FOR PREPROCESSING AND DATA WRANGLING


class project:
    """

    Class for loading and preprocessing DLC data of individual and multiple animals. All main computations are called
    here.

    """
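
    # Hypothetical quick-start sketch (paths and parameter values below are
    # illustrative, not part of this module). A project expects "Videos" and
    # "Tables" subfolders with the DLC output inside `path`:
    #
    #   my_project = project(
    #       path=os.path.join(".", "my_experiment"),
    #       arena="circular",
    #       arena_dims=(380,),
    #       exp_conditions={"Test 1": "WT", "Test 2": "KO"},
    #   )
    #   my_coordinates = my_project.run(verbose=True)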

    def __init__(
        self,
        animal_ids: List = tuple([""]),
        arena: str = "circular",
        arena_dims: tuple = (1,),
        enable_iterative_imputation: bool = None,
        exclude_bodyparts: List = tuple([""]),
        exp_conditions: dict = None,
        interpolate_outliers: bool = True,
        interpolation_limit: int = 5,
        interpolation_std: int = 5,
        likelihood_tol: float = 0.25,
        model: str = "mouse_topview",
        path: str = deepof.utils.os.path.join("."),
        smooth_alpha: float = 0.99,
        table_format: str = "autodetect",
        video_format: str = ".mp4",
    ):

        self.path = path
        self.video_path = os.path.join(self.path, "Videos")
        self.table_path = os.path.join(self.path, "Tables")

        self.table_format = table_format
        if self.table_format == "autodetect":
            ex = [i for i in os.listdir(self.table_path) if not i.startswith(".")][0]
            if ".h5" in ex:
                self.table_format = ".h5"
            elif ".csv" in ex:
                self.table_format = ".csv"

        self.videos = sorted(
            [
                vid
                for vid in deepof.utils.os.listdir(self.video_path)
                if vid.endswith(video_format) and not vid.startswith(".")
            ]
        )
        self.tables = sorted(
            [
                tab
                for tab in deepof.utils.os.listdir(self.table_path)
                if tab.endswith(self.table_format) and not tab.startswith(".")
            ]
        )
        self.angles = True
        self.animal_ids = animal_ids
        self.arena = arena
        self.arena_dims = arena_dims
        self.distances = "all"
        self.ego = False
        self.exp_conditions = exp_conditions
        self.interpolate_outliers = interpolate_outliers
        self.interpolation_limit = interpolation_limit
        self.interpolation_std = interpolation_std
        self.likelihood_tolerance = likelihood_tol
        self.scales = self.get_scale
        self.smooth_alpha = smooth_alpha
        self.subset_condition = None
        self.video_format = video_format
        if enable_iterative_imputation is None:
            self.enable_iterative_imputation = self.animal_ids == tuple([""])
        else:
            self.enable_iterative_imputation = enable_iterative_imputation

        model_dict = {
            "mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
        }
        self.connectivity = model_dict[model]
        self.exclude_bodyparts = exclude_bodyparts
        if self.exclude_bodyparts != tuple([""]):
            for bp in exclude_bodyparts:
                self.connectivity.remove_node(bp)

    def __str__(self):
        if self.exp_conditions:
            return "deepof analysis of {} videos across {} conditions".format(
                len(self.videos), len(set(self.exp_conditions.values()))
            )
        else:
            return "deepof analysis of {} videos".format(len(self.videos))

    @property
    def subset_condition(self):
        """Sets a subset condition for the videos to load. If set,
        only the videos with the included pattern will be loaded"""
        return self._subset_condition

    @property
    def distances(self):
        """List. If not 'all', sets the body parts among which the
        distances will be computed"""
        return self._distances

    @property
    def ego(self):
        """String, name of a body part. If True, computes only the distances
        between the specified body part and the rest"""
        return self._ego

    @property
    def angles(self):
        """Bool. Toggles angle computation. True by default. If turned off,
158
        enhances performance for big datasets"""
159
160
161
162
163
164
165
166
167
168
        return self._angles

    @property
    def get_scale(self) -> np.array:
        """Returns the arena as recognised from the videos"""

        if self.arena in ["circular"]:

            scales = []
            for vid_index, _ in enumerate(self.videos):

                ellipse = deepof.utils.recognize_arena(
                    self.videos,
                    vid_index,
                    path=self.video_path,
                    arena_type=self.arena,
                )[0]

                scales.append(
                    list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1]]) * 2)
                    + list(self.arena_dims)
                )

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales)

    def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
        """Loads videos and tables into dictionaries"""

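        # Keys of the returned dictionaries are taken from everything preceding the
        # literal "DLC" substring in each tracking file name (e.g. a file called
        # "Test 1DLC_resnet50_xxx.h5" would be stored under the key "Test 1"; the
        # file name is illustrative).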
        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        tab_dict = {}

        if self.table_format == ".h5":

            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_hdf(
                    deepof.utils.os.path.join(self.table_path, tab), dtype=float
                )
                for tab in self.tables
            }
        elif self.table_format == ".csv":

            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_csv(
                    deepof.utils.os.path.join(self.table_path, tab),
                    header=[0, 1, 2],
                    index_col=0,
                    dtype=float,
                )
                for tab in self.tables
            }

        lik_dict = defaultdict()

        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik = value.xs("likelihood", level="coords", axis=1, drop_level=True)

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik.droplevel("scorer", axis=1)

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    deepof.utils.smooth_mult_trajectory(
                        np.array(tab), alpha=self.smooth_alpha
                    )
                )
                smooth.columns = cols
                tab_dict[key] = smooth.iloc[1:, :].reset_index(drop=True)

        for key, tab in tab_dict.items():
            tab_dict[key] = tab[tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in tab_dict.items():
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                tab_dict[key] = tab

        if self.exclude_bodyparts != tuple([""]):

            for k, value in tab_dict.items():
                temp = value.drop(self.exclude_bodyparts, axis=1, level="bodyparts")
                temp.sort_index(axis=1, inplace=True)
                temp.columns = pd.MultiIndex.from_product(
                    [sorted(list(set([i[j] for i in temp.columns]))) for j in range(2)]
                )
                tab_dict[k] = temp.sort_index(axis=1)

        if self.interpolate_outliers:

            if verbose:
                print("Interpolating outliers...")

            for k, value in tab_dict.items():
                tab_dict[k] = deepof.utils.interpolate_outliers(
                    value,
                    lik_dict[k],
                    likelihood_tolerance=self.likelihood_tolerance,
                    mode="or",
                    limit=self.interpolation_limit,
                    n_std=self.interpolation_std,
                )

        if self.enable_iterative_imputation:

            if verbose:
                print("Iterative imputation of occluded bodyparts...")

            for k, value in tab_dict.items():
                imputed = IterativeImputer(
                    max_iter=250, skip_complete=True
                ).fit_transform(value)
                tab_dict[k] = pd.DataFrame(
                    imputed, index=value.index, columns=value.columns
                )

        return tab_dict, lik_dict

    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """Computes the distances between all selected body parts over time.
310
        If ego is provided, it only returns distances to a specified bodypart"""
311

lucas_miranda's avatar
lucas_miranda committed
312
313
314
        if verbose:
            print("Computing distances...")

315
        nodes = self.distances
316
        if nodes == "all":
lucas_miranda's avatar
lucas_miranda committed
317
            nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]
318
319

        assert [
lucas_miranda's avatar
lucas_miranda committed
320
            i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
321
322
323
324
        ], "Nodes should correspond to existent bodyparts"

        scales = self.scales[:, 2:]

325
        distance_dict = {
326
327
328
329
330
            key: deepof.utils.bpart_distance(
                tab,
                scales[i, 1],
                scales[i, 0],
            )
lucas_miranda's avatar
lucas_miranda committed
331
            for i, (key, tab) in enumerate(tab_dict.items())
332
        }
333

lucas_miranda's avatar
lucas_miranda committed
334
335
        for key in distance_dict.keys():
            distance_dict[key] = distance_dict[key].loc[
336
337
                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
            ]
lucas_miranda's avatar
lucas_miranda committed
338

339
340
341
        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
342
343
                    :, [dist for dist in val.columns if self.ego in dist]
                ]
344
345
346

        return distance_dict

    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """

        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        cliques = deepof.utils.nx.enumerate_all_cliques(self.connectivity)
        cliques = [i for i in cliques if len(i) == 3]

        angle_dict = {}
        for key, tab in tab_dict.items():

            dats = []
            for clique in cliques:
                dat = pd.DataFrame(
                    deepof.utils.angle_trio(
                        np.array(tab[clique]).reshape([3, tab.shape[0], 2])
                    )
                ).T

                orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]
                dat.columns = [tuple(clique[i] for i in order) for order in orders]

                dats.append(dat)

            dats = pd.concat(dats, axis=1)

            angle_dict[key] = dats

        return angle_dict

    def run(self, verbose: bool = True) -> Coordinates:
        """Generates a dataset using all the options specified during initialization"""

        tables, quality = self.load_tables(verbose)
        distances = None
        angles = None

        if self.distances:
            distances = self.get_distances(tables, verbose)

        if self.angles:
            angles = self.get_angles(tables, verbose)

        if verbose:
            print("Done!")

        return coordinates(
            angles=angles,
            animal_ids=self.animal_ids,
            arena=self.arena,
            arena_dims=self.arena_dims,
            distances=distances,
            exp_conditions=self.exp_conditions,
            path=self.path,
            quality=quality,
            scales=self.scales,
            tables=tables,
            videos=self.videos,
        )

    @subset_condition.setter
    def subset_condition(self, value):
        self._subset_condition = value

    @distances.setter
    def distances(self, value):
        self._distances = value

    @ego.setter
    def ego(self, value):
        self._ego = value

    @angles.setter
    def angles(self, value):
        self._angles = value


class coordinates:
    """

    Class for storing the results of a project run. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

    def __init__(
        self,
        arena: str,
        arena_dims: np.array,
        path: str,
        quality: dict,
        scales: np.array,
        tables: dict,
        videos: list,
        angles: dict = None,
        animal_ids: List = tuple([""]),
        distances: dict = None,
        exp_conditions: dict = None,
    ):
        self._animal_ids = animal_ids
        self._arena = arena
        self._arena_dims = arena_dims
        self._exp_conditions = exp_conditions
        self._path = path
        self._quality = quality
        self._scales = scales
        self._tables = tables
        self._videos = videos
        self.angles = angles
        self.distances = distances

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
                len(self._videos), len(set(self._exp_conditions.values()))
            )
        else:
            return "Coordinates of {} videos".format(len(self._videos))

    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
        align_inplace: bool = False,
        propagate_labels: bool = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If false, the raw data is returned; if 'arena' (default), coordinates are
                centered in the pitch
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (bool): selects the body part to which later processes will align the frames with
                (see preprocess in table_dict documentation).
501
502
                - align_inplace (bool): Only valid if align is set. Aligns the vector that goes from the origin to
                the selected body part with the y axis, for all time points.
503
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
504
505
506
507
508

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

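        # Illustrative example ("Center" and "Spine_1" stand for body part names that
        # must exist in the loaded tables):
        #
        #   coords = my_coordinates.get_coords(center="Center", align="Spine_1")
        #
        # returns a table_dict with one centered, aligned DataFrame per experiment.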
        tabs = deepof.utils.deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = deepof.utils.tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":

                for i, (key, value) in enumerate(tabs.items()):

                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif type(center) == str and center != "arena":

            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)
                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = deepof.utils.rolling_speed(tab, deriv=speed, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype("timedelta64[s]")

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns
                tab = tab[columns]
                tabs[key] = tab

                if align_inplace and polar is False:
                    index = tab.columns
                    tab = pd.DataFrame(
                        deepof.utils.align_trajectories(np.array(tab), mode="all")
                    )
                    tab.columns = index
                    tabs[key] = tab

        if propagate_labels:
            for key, tab in tabs.items():
                tab["pheno"] = self._exp_conditions[key]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
            propagate_labels=propagate_labels,
        )

    def get_distances(
        self, speed: int = 0, length: str = None, propagate_labels: bool = False
    ) -> Table_dict:
        """
        Returns a table_dict object with the distances between body parts of each animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """
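        # Illustrative example: inter-body-part distances, plus their rate of change:
        #   dists = my_coordinates.get_distances()
        #   dist_speeds = my_coordinates.get_distances(speed=1)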

        tabs = deepof.utils.deepcopy(self.distances)

        if self.distances is not None:

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="dists")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab["pheno"] = self._exp_conditions[key]

            return table_dict(tabs, propagate_labels=propagate_labels, typ="dists")

        raise ValueError(
            "Distances not computed. Read the documentation for more details"
        )  # pragma: no cover

    def get_angles(
        self,
        degrees: bool = False,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts animal as values.

            Parameters:
                - angles (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
672
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
673
674
675
676

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """
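        # Illustrative example: angles in degrees instead of radians:
        #   angles = my_coordinates.get_angles(degrees=True)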

        tabs = deepof.utils.deepcopy(self.angles)

        if self.angles is not None:

            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="angles")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab["pheno"] = self._exp_conditions[key]

            return table_dict(tabs, propagate_labels=propagate_labels, typ="angles")

        raise ValueError(
            "Angles not computed. Read the documentation for more details"
        )  # pragma: no cover

    def get_videos(self, play: bool = False):
        """Retuens the videos associated with the dataset as a list."""

708
        if play:  # pragma: no cover
709
710
711
712
713
714
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
715
716
        """Returns the stored dictionary with experimental conditions per subject"""

717
718
        return self._exp_conditions

719
    def get_quality(self):
720
721
        """Retrieves a dictionary with the tagging quality per video, as reported by DLC"""

722
723
724
725
        return self._quality

    @property
    def get_arenas(self):
726
727
        """Retrieves all available information associated with the arena"""

728
729
        return self._arena, self._arena_dims, self._scales

    # noinspection PyDefaultArgument
    def rule_based_annotation(
        self,
        params: Dict = {},
        video_output: bool = False,
        frame_limit: int = np.inf,
        debug: bool = False,
    ) -> Table_dict:
        """Annotates coordinates using a simple rule-based pipeline"""

        tag_dict = {}
        # noinspection PyTypeChecker
        coords = self.get_coords(center=False)
        dists = self.get_distances()
        speeds = self.get_coords(speed=1)

        for key in tqdm(self._tables.keys()):

            video = [vid for vid in self._videos if key + "DLC" in vid][0]
            tag_dict[key] = deepof.pose_utils.rule_based_tagging(
                list(self._tables.keys()),
                self._videos,
                self,
                coords,
                dists,
                speeds,
                self._videos.index(video),
                arena_type=self._arena,
                recog_limit=1,
                path=os.path.join(self._path, "Videos"),
                params=params,
            )

        if video_output:  # pragma: no cover

            def output_video(idx):
                """Outputs a single annotated video. Enclosed in a function to enable parallelization"""

                deepof.pose_utils.rule_based_video(
                    self,
                    list(self._tables.keys()),
                    self._videos,
                    list(self._tables.keys()).index(idx),
                    tag_dict[idx],
                    debug=debug,
                    frame_limit=frame_limit,
                    recog_limit=1,
                    path=os.path.join(self._path, "Videos"),
                    params=params,
                )
                pbar.update(1)

            if type(video_output) == list:
                vid_idxs = video_output
            elif video_output == "all":
                vid_idxs = list(self._tables.keys())
            else:
                raise AttributeError(
                    "Video output must be either 'all' or a list with the names of the videos to render"
                )

            njobs = cpu_count() // 2
            pbar = tqdm(total=len(vid_idxs))
            with parallel_backend("threading", n_jobs=njobs):
                Parallel()(delayed(output_video)(key) for key in vid_idxs)
            pbar.close()

        return table_dict(
            tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
        )

    @staticmethod
    def deep_unsupervised_embedding(
        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        batch_size: int = 256,
        encoding_size: int = 4,
        epochs: int = 35,
        hparams: dict = None,
        kl_warmup: int = 0,
        log_history: bool = True,
        log_hparams: bool = False,
        loss: str = "ELBO",
        mmd_warmup: int = 0,
        montecarlo_kl: int = 10,
        n_components: int = 25,
        output_path: str = ".",
        phenotype_class: float = 0,
        predictor: float = 0,
        pretrained: str = False,
        save_checkpoints: bool = False,
        save_weights: bool = True,
        variational: bool = True,
    ) -> Tuple:
        """
824
825
        Annotates coordinates using an unsupervised autoencoder.
        Full implementation in deepof.train_utils.deep_unsupervised_embedding
826
827
828
829
830

        Parameters:
            - preprocessed_object (Tuple[np.ndarray]): tuple containing a preprocessed object (X_train,
            y_train, X_test, y_test)
            - encoding_size (int): number of dimensions in the latent space of the autoencoder
831
            - epochs (int): epochs during which to train the models
832
            - batch_size (int): training batch size
833
            - save_checkpoints (bool): if True, training checkpoints are saved to disk. Useful for debugging,
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
            but can make training significantly slower
            - hparams (dict): dictionary to change architecture hyperparameters of the autoencoders
            (see documentation for details)
            - kl_warmup (int): number of epochs over which to increase KL weight linearly
            (default is number of epochs // 4)
            - loss (str): Loss function to use. Currently, 'ELBO', 'MMD' and 'ELBO+MMD' are supported.
            - mmd_warmup (int): number of epochs over which to increase MMD weight linearly
            (default is number of epochs // 4)
            - montecarlo_kl (int): Number of Montecarlo samples used to estimate the KL between latent space and prior
            - n_components (int): Number of components of the Gaussian Mixture in the latent space
            - outpath (str): Path where to save the training loggings
            - phenotype_class (float): weight assigned to phenotype classification. If > 0,
            a classification neural network is appended to the latent space,
            aiming to enforce structure from a set of labels in the encoding.
            - predictor (float): weight assigned to a predictor branch. If > 0, a regression neural network
            is appended to the latent space,
            aiming to predict what happens immediately next in the sequence, which can help with regularization.
            - pretrained (bool): If True, a pretrained set of weights is expected.
            - variational (bool): If True (default) a variational autoencoder is used. If False,
            a simple autoencoder is used for dimensionality reduction

        Returns:
            - return_list (tuple): List containing all relevant trained models for unsupervised prediction.

        """
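        # Hypothetical call; producing the (X_train, y_train, X_test, y_test) tuple
        # from a table_dict is assumed to happen upstream and is not shown here:
        #
        #   models = coordinates.deep_unsupervised_embedding(
        #       preprocessed_object=(X_train, y_train, X_test, y_test),
        #       encoding_size=6,
        #       loss="ELBO",
        #       n_components=10,
        #   )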

        trained_models = deepof.train_utils.autoencoder_fitting(
            preprocessed_object=preprocessed_object,
            batch_size=batch_size,
            encoding_size=encoding_size,
            epochs=epochs,
            hparams=hparams,
            kl_warmup=kl_warmup,
            log_history=log_history,
            log_hparams=log_hparams,
            loss=loss,
            mmd_warmup=mmd_warmup,
            montecarlo_kl=montecarlo_kl,
            n_components=n_components,
            output_path=output_path,
            phenotype_class=phenotype_class,
            predictor=predictor,
            pretrained=pretrained,
            save_checkpoints=save_checkpoints,
            save_weights=save_weights,
            variational=variational,
        )

        # returns a list of trained tensorflow models
        return trained_models


class table_dict(dict):
    """

    Main class for storing a single dataset as a dictionary with individuals as keys and pandas.DataFrames as values.
    Includes methods for generating training and testing datasets for the autoencoders.

    """

    def __init__(
        self,
        tabs: Dict,
        typ: str,
        arena: str = None,
        arena_dims: np.array = None,
        center: str = None,
        polar: bool = None,
        propagate_labels: bool = False,
    ):
        super().__init__(tabs)
        self._type = typ
        self._center = center
        self._polar = polar
        self._arena = arena
        self._arena_dims = arena_dims
        self._propagate_labels = propagate_labels

    def filter_videos(self, keys: list) -> Table_dict:
        """Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
        for selecting data coming from videos of a specified condition."""
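        # Illustrative example: subset = my_tab_dict.filter_videos(["Test 1", "Test 2"])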

        assert np.all([k in self.keys() for k in keys]), "Invalid keys selected"

        return table_dict(
            {k: value for k, value in self.items() if k in keys}, self._type
        )

    # noinspection PyTypeChecker
    def plot_heatmaps(
        self,
        bodyparts: list,
        xlim: float = None,
        ylim: float = None,
        save: bool = False,
        i: int = 0,
        dpi: int = 100,
    ) -> plt.figure:
        """Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""

        if self._type != "coords" or self._polar:
            raise NotImplementedError(
                "Heatmaps only available for cartesian coordinates. "
                "Set polar to False in get_coords and try again"
            )  # pragma: no cover

        if not self._center:  # pragma: no cover
            warnings.warn("Heatmaps look better if you center the data")

        if self._arena == "circular":

            heatmaps = deepof.visuals.plot_heatmap(
                list(self.values())[i],
                bodyparts,
                xlim=xlim,
                ylim=ylim,
                save=save,
                dpi=dpi,
            )

            return heatmaps

    def get_training_set(
        self,
        test_videos: int = 0,
        encode_labels: bool = True,
    ) -> Tuple[np.ndarray, list, Union[np.ndarray, list], list]:
        """Generates training and test sets as numpy.array objects for model training"""

        # Padding of videos with slightly different lengths
        raw_data = np.array([np.array(v) for v in self.values()], dtype=object)
        if self._propagate_labels: