data.py 52.6 KB
Newer Older
lucas_miranda's avatar
lucas_miranda committed
1
# @author lucasmiranda42
2
3
# encoding: utf-8
# module deepof
lucas_miranda's avatar
lucas_miranda committed
4

5
6
7
8
9
10
11
12
13
14
15
16
"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations for getting the data into the desired shape
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.

"""

lucas_miranda's avatar
lucas_miranda committed
17
18
import os
import warnings
19
from collections import defaultdict
lucas_miranda's avatar
lucas_miranda committed
20
from multiprocessing import cpu_count
lucas_miranda's avatar
lucas_miranda committed
21
22
23
24
25
from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
26
import seaborn as sns
lucas_miranda's avatar
lucas_miranda committed
27
28
import tensorflow as tf
from joblib import delayed, Parallel, parallel_backend
29
from pkg_resources import resource_filename
30
31
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
lucas_miranda's avatar
lucas_miranda committed
32
from sklearn.experimental import enable_iterative_imputer
33
from sklearn.impute import IterativeImputer
34
from sklearn.manifold import TSNE
35
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
36
from tqdm import tqdm
lucas_miranda's avatar
lucas_miranda committed
37

38
import deepof.models
39
import deepof.pose_utils
lucas_miranda's avatar
lucas_miranda committed
40
import deepof.train_utils
41
42
import deepof.utils
import deepof.visuals
43

44
45
# Remove excessive logging from tensorflow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
46

47
# DEFINE CUSTOM ANNOTATED TYPES #
48
49
Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)
50

51

52
# CLASSES FOR PREPROCESSING AND DATA WRANGLING
53

54

55
class project:
lucas_miranda's avatar
lucas_miranda committed
56
57
    """

58
59
    Class for loading and preprocessing DLC data of individual and multiple animals. All main computations are called
    here.
lucas_miranda's avatar
lucas_miranda committed
60
61

    """
62
63

    def __init__(
        self,
        animal_ids: List = tuple([""]),
        arena: str = "circular",
        arena_detection: str = "rule-based",
        arena_dims: tuple = (1,),
        enable_iterative_imputation: bool = None,
        exclude_bodyparts: List = tuple([""]),
        exp_conditions: dict = None,
        interpolate_outliers: bool = True,
        interpolation_limit: int = 5,
        interpolation_std: int = 5,
        likelihood_tol: float = 0.25,
        model: str = "mouse_topview",
        path: str = deepof.utils.os.path.join("."),
        smooth_alpha: float = 0.99,
        table_format: str = "autodetect",
        video_format: str = ".mp4",
    ):
        """Initialises a DLC preprocessing project rooted at `path`.

        Expects `path` to contain a "Videos" and a "Tables" subdirectory with
        one tracking table per video. Parameters are stored for later use by
        load_tables / get_distances / get_angles; no heavy computation happens
        here except arena detection (via the get_arena property) and, when
        arena_detection == "cnn", loading a bundled Keras ellipse model.

        Raises:
            AssertionError: if the number of videos and tables differ.
        """

        # Set working paths
        self.path = path
        self.video_path = os.path.join(self.path, "Videos")
        self.table_path = os.path.join(self.path, "Tables")
        self.trained_path = resource_filename(__name__, "trained_models")

        # Detect files to load from disk. NOTE(review): autodetect inspects a
        # single arbitrary file; if it is neither .h5 nor .csv, table_format
        # stays "autodetect" and load_tables will raise NotImplementedError.
        self.table_format = table_format
        if self.table_format == "autodetect":
            # Raises IndexError if the Tables directory is empty — TODO confirm
            # whether a friendlier error is wanted upstream.
            ex = [i for i in os.listdir(self.table_path) if not i.startswith(".")][0]
            if ".h5" in ex:
                self.table_format = ".h5"
            elif ".csv" in ex:
                self.table_format = ".csv"
        self.videos = sorted(
            [
                vid
                for vid in deepof.utils.os.listdir(self.video_path)
                if vid.endswith(video_format) and not vid.startswith(".")
            ]
        )
        self.tables = sorted(
            [
                tab
                for tab in deepof.utils.os.listdir(self.table_path)
                if tab.endswith(self.table_format) and not tab.startswith(".")
            ]
        )
        # Sorted video/table lists are later paired by position, so counts
        # must match.
        assert len(self.videos) == len(
            self.tables
        ), "Unequal number of videos and tables. Please check your file structure"

        # Loads arena details and (if needed) detection models
        self.arena = arena
        self.arena_detection = arena_detection
        self.arena_dims = arena_dims
        self.ellipse_detection = None
        if arena == "circular" and arena_detection == "cnn":
            # Load the first bundled model whose filename starts with
            # "elliptical" from the packaged trained_models directory.
            self.ellipse_detection = tf.keras.models.load_model(
                [
                    os.path.join(self.trained_path, i)
                    for i in os.listdir(self.trained_path)
                    if i.startswith("elliptical")
                ][0]
            )
        # get_arena is a property that scans every video; this is the one
        # expensive step performed at construction time.
        self.scales, self.arena_params, self.video_resolution = self.get_arena

        # Set the rest of the init parameters
        self.angles = True
        self.animal_ids = animal_ids
        self.distances = "all"
        self.ego = False
        self.exp_conditions = exp_conditions
        self.interpolate_outliers = interpolate_outliers
        self.interpolation_limit = interpolation_limit
        self.interpolation_std = interpolation_std
        self.likelihood_tolerance = likelihood_tol
        self.smooth_alpha = smooth_alpha
        self.subset_condition = None
        self.video_format = video_format
        # Iterative imputation defaults to on only for single-animal projects
        # (animal_ids left at its sentinel default).
        if enable_iterative_imputation is None:
            self.enable_iterative_imputation = self.animal_ids == tuple([""])
        else:
            self.enable_iterative_imputation = enable_iterative_imputation

        # Bodypart connectivity graph; currently only the mouse topview model
        # is available. An unknown `model` raises KeyError here.
        model_dict = {
            "mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
        }
        self.connectivity = model_dict[model]
        self.exclude_bodyparts = exclude_bodyparts
        if self.exclude_bodyparts != tuple([""]):
            for bp in exclude_bodyparts:
                self.connectivity.remove_node(bp)
lucas_miranda's avatar
lucas_miranda committed
157

158
159
    def __str__(self):
        if self.exp_conditions:
160
            return "deepof analysis of {} videos across {} conditions".format(
161
                len(self.videos), len(set(self.exp_conditions.values()))
162
163
            )
        else:
164
            return "deepof analysis of {} videos".format(len(self.videos))
165

166
167
168
169
170
171
172
173
    @property
    def subset_condition(self):
        """Pattern used to subset bodypart columns when loading tables.
        If set, only bodyparts whose name starts with this pattern are kept
        (see load_tables)"""
        return self._subset_condition

    @property
    def distances(self):
        """List, or the string 'all'. If not 'all', restricts the body parts
        among which the distances will be computed"""
        return self._distances

    @property
    def ego(self):
        """String, name of a body part. If set, computes only the distances
        between the specified body part and the rest"""
        return self._ego

    @property
    def angles(self):
        """Bool. Toggles angle computation. True by default. If turned off,
        enhances performance for big datasets"""
        return self._angles

    @property
    def get_arena(self) -> np.array:
        """Returns the arena as recognised from the videos.

        Runs deepof.utils.recognize_arena on every video and returns a tuple:
            - scales (np.array): per-video [x, y, diameter] of the detected
              ellipse concatenated with self.arena_dims
            - arena_params (list): raw ellipse parameters per video
            - video_resolution (list): (height, width) per video

        Raises:
            NotImplementedError: if self.arena is not 'circular'.
        """

        scales = []
        arena_params = []
        video_resolution = []

        if self.arena in ["circular"]:

            for vid_index, _ in enumerate(self.videos):
                # ellipse is ((center_x, center_y), (axis_a, axis_b), ...) —
                # presumably OpenCV ellipse format; TODO confirm in
                # deepof.utils.recognize_arena.
                ellipse, h, w = deepof.utils.recognize_arena(
                    self.videos,
                    vid_index,
                    path=self.video_path,
                    arena_type=self.arena,
                    detection_mode=self.arena_detection,
                    cnn_model=self.ellipse_detection,
                )

                # Third scale entry is twice the second ellipse axis, i.e. a
                # diameter in pixels.
                scales.append(
                    list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1] * 2]))
                    + list(self.arena_dims)
                )
                arena_params.append(ellipse)
                video_resolution.append((h, w))

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales), arena_params, video_resolution
221

222
    def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
        """Loads videos and tables into dictionaries.

        Reads every tracking table, splits coordinates from likelihoods, and
        applies (in order): smoothing, bodypart subsetting, bodypart exclusion,
        outlier interpolation and iterative imputation, each gated by the
        corresponding attribute set at construction time.

        Parameters:
            - verbose (bool): prints progress messages if True.

        Returns:
            (tab_dict, lik_dict): dict of coordinate DataFrames per video key,
            and dict of per-bodypart likelihood DataFrames per video key.

        Raises:
            NotImplementedError: if table_format is neither '.h5' nor '.csv'.
        """

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        tab_dict = {}

        if self.table_format == ".h5":

            # Keys are the filename prefixes up to the "DLC" marker inserted
            # by DeepLabCut.
            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_hdf(
                    deepof.utils.os.path.join(self.table_path, tab), dtype=float
                )
                for tab in self.tables
            }
        elif self.table_format == ".csv":

            # DLC csv exports carry a 3-row header (scorer/bodyparts/coords).
            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_csv(
                    deepof.utils.os.path.join(self.table_path, tab),
                    header=[0, 1, 2],
                    index_col=0,
                    dtype=float,
                )
                for tab in self.tables
            }

        lik_dict = defaultdict()

        # Split x/y coordinates from the likelihood channel.
        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik = value.xs("likelihood", level="coords", axis=1, drop_level=True)

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik.droplevel("scorer", axis=1)

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    deepof.utils.smooth_mult_trajectory(
                        np.array(tab), alpha=self.smooth_alpha
                    )
                )
                smooth.columns = cols
                # First row is dropped — presumably a smoothing boundary
                # artifact; TODO confirm against smooth_mult_trajectory.
                tab_dict[key] = smooth.iloc[1:, :].reset_index(drop=True)

        # Drop the scorer level, keeping only the first scorer's columns.
        for key, tab in tab_dict.items():
            tab_dict[key] = tab.loc[:, tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in tab_dict.items():
                # Bodyparts NOT matching the subset prefix, to be dropped.
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                # Reassign columns so the MultiIndex carries no stale levels.
                tab.columns = tabcols

                tab_dict[key] = tab

        if self.exclude_bodyparts != tuple([""]):

            for k, value in tab_dict.items():
                temp = value.drop(self.exclude_bodyparts, axis=1, level="bodyparts")
                temp.sort_index(axis=1, inplace=True)
                # Rebuild the column MultiIndex as the full bodypart x coord
                # product; assumes every remaining bodypart has both x and y.
                temp.columns = pd.MultiIndex.from_product(
                    [sorted(list(set([i[j] for i in temp.columns]))) for j in range(2)]
                )
                tab_dict[k] = temp.sort_index(axis=1)

        if self.interpolate_outliers:

            if verbose:
                print("Interpolating outliers...")

            for k, value in tab_dict.items():
                tab_dict[k] = deepof.utils.interpolate_outliers(
                    value,
                    lik_dict[k],
                    likelihood_tolerance=self.likelihood_tolerance,
                    mode="or",
                    limit=self.interpolation_limit,
                    n_std=self.interpolation_std,
                )

        if self.enable_iterative_imputation:

            if verbose:
                print("Iterative imputation of ocluded bodyparts...")

            for k, value in tab_dict.items():
                imputed = IterativeImputer(
                    max_iter=5, skip_complete=True
                ).fit_transform(value)
                tab_dict[k] = pd.DataFrame(
                    imputed, index=value.index, columns=value.columns
                )

        return tab_dict, lik_dict
342

343
344
    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """Computes the distances between all selected body parts over time.
345
        If ego is provided, it only returns distances to a specified bodypart"""
346

lucas_miranda's avatar
lucas_miranda committed
347
348
349
        if verbose:
            print("Computing distances...")

350
        nodes = self.distances
351
        if nodes == "all":
lucas_miranda's avatar
lucas_miranda committed
352
            nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]
353
354

        assert [
lucas_miranda's avatar
lucas_miranda committed
355
            i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
356
357
358
359
        ], "Nodes should correspond to existent bodyparts"

        scales = self.scales[:, 2:]

360
        distance_dict = {
361
362
363
364
365
            key: deepof.utils.bpart_distance(
                tab,
                scales[i, 1],
                scales[i, 0],
            )
lucas_miranda's avatar
lucas_miranda committed
366
            for i, (key, tab) in enumerate(tab_dict.items())
367
        }
368

lucas_miranda's avatar
lucas_miranda committed
369
370
        for key in distance_dict.keys():
            distance_dict[key] = distance_dict[key].loc[
371
372
                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
            ]
lucas_miranda's avatar
lucas_miranda committed
373

374
375
376
        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
377
378
                    :, [dist for dist in val.columns if self.ego in dist]
                ]
379
380
381

        return distance_dict

382
    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """
        Computes all the angles between adjacent bodypart trios per video and
        per frame in the data.

        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each
            bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        # Only bodypart trios forming a 3-clique in the connectivity graph.
        trios = [
            clique
            for clique in deepof.utils.nx.enumerate_all_cliques(self.connectivity)
            if len(clique) == 3
        ]

        # Each trio yields three angles, one per vertex ordering.
        orderings = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]

        angle_dict = {}
        for key, tab in tab_dict.items():

            per_trio = []
            for trio in trios:
                frame_angles = pd.DataFrame(
                    deepof.utils.angle_trio(
                        np.array(tab[trio]).reshape([3, tab.shape[0], 2])
                    )
                ).T

                frame_angles.columns = [
                    tuple(trio[i] for i in order) for order in orderings
                ]

                per_trio.append(frame_angles)

            angle_dict[key] = pd.concat(per_trio, axis=1)

        return angle_dict
422

423
    def run(self, verbose: bool = True) -> Coordinates:
        """Generates a dataset using all the options specified during
        initialization: loads the tables, then computes distances and angles
        as toggled, and wraps everything in a coordinates object."""

        tables, quality = self.load_tables(verbose)

        distances = self.get_distances(tables, verbose) if self.distances else None
        angles = self.get_angles(tables, verbose) if self.angles else None

        if verbose:
            print("Done!")

        return coordinates(
            angles=angles,
            animal_ids=self.animal_ids,
            arena=self.arena,
            arena_detection=self.arena_detection,
            arena_dims=self.arena_dims,
            distances=distances,
            exp_conditions=self.exp_conditions,
            path=self.path,
            quality=quality,
            scales=self.scales,
            arena_params=self.arena_params,
            tables=tables,
            videos=self.videos,
            video_resolution=self.video_resolution,
        )

456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
    @subset_condition.setter
    def subset_condition(self, value):
        # Prefix pattern used by load_tables to filter bodypart columns.
        self._subset_condition = value

    @distances.setter
    def distances(self, value):
        # Either the string "all" or an explicit list of bodyparts.
        self._distances = value

    @ego.setter
    def ego(self, value):
        # Falsy to disable; otherwise the name of the reference bodypart.
        self._ego = value

    @angles.setter
    def angles(self, value):
        # Bool toggling angle computation in run().
        self._angles = value

472
473

class coordinates:
474
475
476
477
478
479
480
    """

    Class for storing the results of a ran project. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

481
    def __init__(
482
483
484
485
486
487
488
        self,
        arena: str,
        arena_detection: str,
        arena_dims: np.array,
        path: str,
        quality: dict,
        scales: np.array,
489
        arena_params: List,
490
        tables: dict,
491
492
        videos: List,
        video_resolution: List,
493
494
495
496
        angles: dict = None,
        animal_ids: List = tuple([""]),
        distances: dict = None,
        exp_conditions: dict = None,
497
    ):
498
        self._animal_ids = animal_ids
499
        self._arena = arena
500
        self._arena_detection = arena_detection
501
        self._arena_params = arena_params
502
        self._arena_dims = arena_dims
503
        self._exp_conditions = exp_conditions
504
        self._path = path
505
506
507
508
        self._quality = quality
        self._scales = scales
        self._tables = tables
        self._videos = videos
509
        self._video_resolution = video_resolution
510
511
        self.angles = angles
        self.distances = distances
512
513
514
515

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
516
                len(self._videos), len(set(self._exp_conditions.values()))
517
518
            )
        else:
519
            return "deepof analysis of {} videos".format(len(self._videos))
520

521
    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
        align_inplace: bool = False,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If false, the raw data is returned; if 'arena' (default), coordinates are
                centered in the pitch
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (bool): selects the body part to which later processes will align the frames with
                (see preprocess in table_dict documentation).
                - align_inplace (bool): Only valid if align is set. Aligns the vector that goes from the origin to
                the selected body part with the y axis, for all time points.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for regularising the latent space based
                on already known traits.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        # Work on a deep copy so the stored raw tables are never mutated.
        tabs = deepof.utils.deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = deepof.utils.tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":

                for i, (key, value) in enumerate(tabs.items()):

                    # Cartesian path; the KeyError branch handles tables
                    # already converted to polar (rho/phi) above.
                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        # NOTE(review): shifts rho/phi by the same half-scale
                        # offsets as x/y — looks questionable for polar
                        # coordinates; confirm intended semantics.
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif isinstance(center, str) and center != "arena":

            # Center on a specific bodypart, then drop that bodypart's columns.
            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)
                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = deepof.utils.rolling_speed(tab, deriv=speed, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype("timedelta64[s]")

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns
                tab = tab[columns]
                tabs[key] = tab

                # In-place alignment only implemented for cartesian data.
                if align_inplace and polar is False:
                    index = tab.columns
                    tab = pd.DataFrame(
                        deepof.utils.align_trajectories(np.array(tab), mode="all")
                    )
                    tab.columns = index
                    tabs[key] = tab

        if propagate_labels:
            # Appends the phenotype as an extra column per frame.
            for key, tab in tabs.items():
                tab.loc[:, "pheno"] = self._exp_conditions[key]

        if propagate_annotations:
            annotations = list(propagate_annotations.values())[0].columns

            for key, tab in tabs.items():
                for ann in annotations:
                    tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
            propagate_labels=propagate_labels,
            propagate_annotations=propagate_annotations,
        )

670
    def get_distances(
        self,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the distances between body parts animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for regularising the latent space based
                on already known traits.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information

            Raises:
                ValueError: if distances were not computed when the project ran.
        """

        # Deep copy so the stored distance tables are never mutated.
        tabs = deepof.utils.deepcopy(self.distances)

        if self.distances is not None:

            if speed:
                for key, tab in tabs.items():
                    # deriv is offset by 1 for distance data — presumably
                    # because distances are already a first-order quantity;
                    # TODO confirm against rolling_speed.
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="dists")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab.loc[:, "pheno"] = self._exp_conditions[key]

            if propagate_annotations:
                annotations = list(propagate_annotations.values())[0].columns

                for key, tab in tabs.items():
                    for ann in annotations:
                        tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

            return table_dict(
                tabs,
                propagate_labels=propagate_labels,
                propagate_annotations=propagate_annotations,
                typ="dists",
            )

        raise ValueError(
            "Distances not computed. Read the documentation for more details"
        )  # pragma: no cover
730

731
    def get_angles(
        self,
        degrees: bool = False,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts animal as values.

            Parameters:
                - degrees (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for regularising the latent space based
                on already known traits.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information

            Raises:
                ValueError: if angles were not computed when the project was run.
        """

        # deepcopy so that in-place transformations below never mutate the stored tables
        tabs = deepof.utils.deepcopy(self.angles)

        if self.angles is not None:

            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
                for key, tab in tabs.items():
                    # deriv=speed + 1: the angle series is already one derivative
                    # "deeper" than raw positions, mirroring get_distances
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="angles")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    # Reindex frames to real video time stamps
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    # .loc for consistency with get_distances and the
                    # propagate_annotations branch below
                    tab.loc[:, "pheno"] = self._exp_conditions[key]

            if propagate_annotations:
                annotations = list(propagate_annotations.values())[0].columns

                for key, tab in tabs.items():
                    for ann in annotations:
                        tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

            return table_dict(
                tabs,
                propagate_labels=propagate_labels,
                propagate_annotations=propagate_annotations,
                typ="angles",
            )

        raise ValueError(
            "Angles not computed. Read the documentation for more details"
        )  # pragma: no cover
lucas_miranda's avatar
lucas_miranda committed
795

796
797
798
    def get_videos(self, play: bool = False):
        """Returns the videos associated with the dataset as a list.

            Parameters:
                - play (bool): NOT IMPLEMENTED — inline playback is not yet supported
                and raises NotImplementedError if requested.

            Returns:
                list: paths of the videos loaded into the project
        """

        if play:  # pragma: no cover
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
        """Dictionary mapping each experiment identifier to its experimental condition."""

        return self._exp_conditions

810
    def get_quality(self):
        """Returns the per-video tracking quality dictionary reported by DeepLabCut."""

        return self._quality

    @property
    def get_arenas(self):
        """All stored arena information: the arena type, its dimensions, and the detected scales."""

        return self._arena, self._arena_dims, self._scales

821
    def rule_based_annotation(
        self,
        params: Dict = None,
        video_output: bool = False,
        frame_limit: float = np.inf,
        debug: bool = False,
        n_jobs: int = 1,
        propagate_labels: bool = False,
    ) -> Table_dict:
        """Annotates coordinates using a simple rule-based pipeline.

            Parameters:
                - params (Dict): overrides for the rule-based tagging parameters.
                None (the default) is equivalent to an empty dict (all defaults).
                - video_output (bool / list / str): if truthy, renders annotated
                videos. Must be either "all" or a list of experiment keys to render.
                - frame_limit (float): maximum number of frames to render per video.
                - debug (bool): if True, renders debug information on the videos.
                - n_jobs (int): number of threads used to render videos in parallel.
                - propagate_labels (bool): if True, adds a "pheno" column with the
                experimental condition of each animal.

            Returns:
                tag_dict (Table_dict): table_dict with rule-based annotations per video

            Raises:
                AttributeError: if video_output is truthy but neither "all" nor a list.
        """

        # None sentinel instead of a mutable default argument, which would be
        # shared (and potentially mutated) across calls
        if params is None:
            params = {}

        tag_dict = {}
        coords = self.get_coords(center=False)
        dists = self.get_distances()
        speeds = self.get_coords(speed=1)

        # noinspection PyTypeChecker
        for key in tqdm(self._tables.keys()):
            tag_dict[key] = deepof.pose_utils.rule_based_tagging(
                self,
                coords=coords,
                dists=dists,
                speeds=speeds,
                # Each experiment key is a prefix of its DLC-produced video name
                video=[vid for vid in self._videos if key + "DLC" in vid][0],
                params=params,
            )

        if propagate_labels:
            for key, tab in tag_dict.items():
                # .loc for consistency with the other table accessors
                tab.loc[:, "pheno"] = self._exp_conditions[key]

        if video_output:  # pragma: no cover

            def output_video(idx):
                """Outputs a single annotated video. Enclosed in a function to enable parallelization"""

                deepof.pose_utils.rule_based_video(
                    self,
                    tag_dict=tag_dict[idx],
                    vid_index=list(self._tables.keys()).index(idx),
                    debug=debug,
                    frame_limit=frame_limit,
                    params=params,
                )
                pbar.update(1)

            if isinstance(video_output, list):
                vid_idxs = video_output
            elif video_output == "all":
                vid_idxs = list(self._tables.keys())
            else:
                raise AttributeError(
                    "Video output must be either 'all' or a list with the names of the videos to render"
                )

            pbar = tqdm(total=len(vid_idxs))
            # Threading backend: the heavy lifting happens in OpenCV/I/O, which
            # releases the GIL, so threads are enough
            with parallel_backend("threading", n_jobs=n_jobs):
                Parallel()(delayed(output_video)(key) for key in vid_idxs)
            pbar.close()

        return table_dict(
            tag_dict,
            typ="rule-based",
            arena=self._arena,
            arena_dims=self._arena_dims,
            propagate_labels=propagate_labels,
        )
889

890
891
    @staticmethod
    def deep_unsupervised_embedding(
892
893
894
        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        batch_size: int = 256,
        encoding_size: int = 4,
895
        epochs: int = 50,
896
        hparams: dict = None,
897
        kl_annealing_mode: str = "linear",
898
899
900
901
        kl_warmup: int = 0,
        log_history: bool = True,
        log_hparams: bool = False,
        loss: str = "ELBO",
902
        mmd_annealing_mode: str = "linear",
903
904
905
        mmd_warmup: int = 0,
        montecarlo_kl: int = 10,
        n_components: int = 25,
906
        overlap_loss: float = 0,
907
        output_path: str = ".",
908
909
910
        next_sequence_prediction: float = 0,
        phenotype_prediction: float = 0,
        rule_based_prediction: float = 0,
911
912
913
914
915
916
917
        pretrained: str = False,
        save_checkpoints: bool = False,
        save_weights: bool = True,
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
        entropy_samples: int = 10000,
        entropy_knn: int = 100,
918
        input_type: str = False,
919
        run: int = 0,
920
        strategy: tf.distribute.Strategy = tf.distribute.MirroredStrategy(),
921
922
    ) -> Tuple:
        """
923
924
        Annotates coordinates using an unsupervised autoencoder.
        Full implementation in deepof.train_utils.deep_unsupervised_embedding
925
926
927
928
929

        Parameters:
            - preprocessed_object (Tuple[np.ndarray]): tuple containing a preprocessed object (X_train,
            y_train, X_test, y_test)
            - encoding_size (int): number of dimensions in the latent space of the autoencoder
930