data.py 48 KB
Newer Older
lucas_miranda's avatar
lucas_miranda committed
1
# @author lucasmiranda42
2
3
# encoding: utf-8
# module deepof
lucas_miranda's avatar
lucas_miranda committed
4

5
6
7
8
9
10
11
12
13
14
15
16
"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations for getting the data into the desired shape
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.

"""

17
from collections import defaultdict
18
from joblib import delayed, Parallel, parallel_backend
19
from typing import Dict, List, Tuple, Union
lucas_miranda's avatar
lucas_miranda committed
20
from multiprocessing import cpu_count
21
22
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
23
24
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
25
from sklearn.manifold import TSNE
26
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
27
from tqdm import tqdm
28
import deepof.models
29
30
31
import deepof.pose_utils
import deepof.utils
import deepof.visuals
32
import deepof.train_utils
33
34
import matplotlib.pyplot as plt
import numpy as np
35
import os
36
37
import pandas as pd
import warnings
38

39
40
# Remove excessive logging from tensorflow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
41

42
# DEFINE CUSTOM ANNOTATED TYPES #

# BUGFIX: the scraped source read "deepof.utils.Newisinstance", an artifact of
# a bad find-and-replace ("type" -> "isinstance"). The intended call is
# typing.NewType, re-exported through deepof.utils. These aliases exist purely
# for annotation purposes (both wrap Any).
Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)


# CLASSES FOR PREPROCESSING AND DATA WRANGLING
47

48

49
class project:
    """
    Class for loading and preprocessing DLC data of individual and multiple
    animals. All main computations are called here.
    """

    def __init__(
        self,
        animal_ids: List = tuple([""]),
        arena: str = "circular",
        arena_dims: tuple = (1,),
        enable_iterative_imputation: bool = None,
        exclude_bodyparts: List = tuple([""]),
        exp_conditions: dict = None,
        interpolate_outliers: bool = True,
        interpolation_limit: int = 5,
        interpolation_std: int = 5,
        likelihood_tol: float = 0.25,
        model: str = "mouse_topview",
        path: str = deepof.utils.os.path.join("."),
        smooth_alpha: float = 0.99,
        table_format: str = "autodetect",
        video_format: str = ".mp4",
    ):

        # Fixed directory layout: videos and tracking tables live in
        # "Videos" and "Tables" subfolders of the project path.
        self.path = path
        self.video_path = os.path.join(self.path, "Videos")
        self.table_path = os.path.join(self.path, "Tables")

        # Autodetect the tracking-table format from the first non-hidden file.
        self.table_format = table_format
        if self.table_format == "autodetect":
            ex = [i for i in os.listdir(self.table_path) if not i.startswith(".")][0]
            if ".h5" in ex:
                self.table_format = ".h5"
            elif ".csv" in ex:
                self.table_format = ".csv"

        self.videos = sorted(
            [
                vid
                for vid in deepof.utils.os.listdir(self.video_path)
                if vid.endswith(video_format) and not vid.startswith(".")
            ]
        )
        self.tables = sorted(
            [
                tab
                for tab in deepof.utils.os.listdir(self.table_path)
                if tab.endswith(self.table_format) and not tab.startswith(".")
            ]
        )
        assert len(self.videos) == len(
            self.tables
        ), "Unequal number of videos and tables. Please check your file structure"

        self.angles = True
        self.animal_ids = animal_ids
        self.arena = arena
        self.arena_dims = arena_dims
        self.distances = "all"
        self.ego = False
        self.exp_conditions = exp_conditions
        self.interpolate_outliers = interpolate_outliers
        self.interpolation_limit = interpolation_limit
        self.interpolation_std = interpolation_std
        self.likelihood_tolerance = likelihood_tol
        # NOTE: get_scale runs arena recognition on every video, so this line
        # triggers video processing at construction time.
        self.scales = self.get_scale
        self.smooth_alpha = smooth_alpha
        self.subset_condition = None
        self.video_format = video_format

        # Iterative imputation defaults to on for single-animal projects only.
        if enable_iterative_imputation is None:
            self.enable_iterative_imputation = self.animal_ids == tuple([""])
        else:
            self.enable_iterative_imputation = enable_iterative_imputation

        # Body-part connectivity graph for the selected experimental model.
        model_dict = {
            "mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
        }
        self.connectivity = model_dict[model]

        self.exclude_bodyparts = exclude_bodyparts
        if self.exclude_bodyparts != tuple([""]):
            for bp in exclude_bodyparts:
                self.connectivity.remove_node(bp)

    def __str__(self):
        if self.exp_conditions:
            return "deepof analysis of {} videos across {} conditions".format(
                len(self.videos), len(set(self.exp_conditions.values()))
            )
        else:
            return "deepof analysis of {} videos".format(len(self.videos))

    @property
    def subset_condition(self):
        """Sets a subset condition for the videos to load. If set,
        only the videos with the included pattern will be loaded"""
        return self._subset_condition

    @property
    def distances(self):
        """List. If not 'all', sets the body parts among which the
        distances will be computed"""
        return self._distances

    @property
    def ego(self):
        """String, name of a body part. If True, computes only the distances
        between the specified body part and the rest"""
        return self._ego

    @property
    def angles(self):
        """Bool. Toggles angle computation. True by default. If turned off,
        enhances performance for big datasets"""
        return self._angles

    @property
    def get_scale(self) -> np.array:
        """Returns the arena as recognised from the videos"""

        if self.arena in ["circular"]:

            scales = []
            for vid_index, _ in enumerate(self.videos):

                ellipse = deepof.utils.recognize_arena(
                    self.videos,
                    vid_index,
                    path=self.video_path,
                    arena_type=self.arena,
                )[0]

                # Detected centre (x, y) and one radius, doubled into
                # diameters, followed by the stated real-world arena dims.
                scales.append(
                    list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1]]) * 2)
                    + list(self.arena_dims)
                )

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales)

    def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
        """Loads videos and tables into dictionaries"""

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        tab_dict = {}

        # Keys are the video names up to the "DLC" scorer suffix.
        if self.table_format == ".h5":
            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_hdf(
                    deepof.utils.os.path.join(self.table_path, tab), dtype=float
                )
                for tab in self.tables
            }
        elif self.table_format == ".csv":
            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_csv(
                    deepof.utils.os.path.join(self.table_path, tab),
                    header=[0, 1, 2],
                    index_col=0,
                    dtype=float,
                )
                for tab in self.tables
            }

        lik_dict = defaultdict()

        # Split coordinates from DLC likelihoods, kept in separate dicts.
        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik = value.xs("likelihood", level="coords", axis=1, drop_level=True)

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik.droplevel("scorer", axis=1)

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    deepof.utils.smooth_mult_trajectory(
                        np.array(tab), alpha=self.smooth_alpha
                    )
                )
                smooth.columns = cols
                # Drop the first (smoothing-artifact) frame.
                tab_dict[key] = smooth.iloc[1:, :].reset_index(drop=True)

        # Drop the scorer level from the column MultiIndex.
        for key, tab in tab_dict.items():
            tab_dict[key] = tab.loc[:, tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in tab_dict.items():
                # Keep only body parts whose name starts with the condition.
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                tab_dict[key] = tab

        if self.exclude_bodyparts != tuple([""]):

            for k, value in tab_dict.items():
                temp = value.drop(self.exclude_bodyparts, axis=1, level="bodyparts")
                temp.sort_index(axis=1, inplace=True)
                # Rebuild a clean (bodyparts x coords) MultiIndex after the drop.
                temp.columns = pd.MultiIndex.from_product(
                    [sorted(list(set([i[j] for i in temp.columns]))) for j in range(2)]
                )
                tab_dict[k] = temp.sort_index(axis=1)

        if self.interpolate_outliers:

            if verbose:
                print("Interpolating outliers...")

            for k, value in tab_dict.items():
                tab_dict[k] = deepof.utils.interpolate_outliers(
                    value,
                    lik_dict[k],
                    likelihood_tolerance=self.likelihood_tolerance,
                    mode="or",
                    limit=self.interpolation_limit,
                    n_std=self.interpolation_std,
                )

        if self.enable_iterative_imputation:

            if verbose:
                # BUGFIX: message typo "ocluded" -> "occluded"
                print("Iterative imputation of occluded bodyparts...")

            for k, value in tab_dict.items():
                imputed = IterativeImputer(
                    max_iter=1000, skip_complete=True
                ).fit_transform(value)
                tab_dict[k] = pd.DataFrame(
                    imputed, index=value.index, columns=value.columns
                )

        return tab_dict, lik_dict

    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """Computes the distances between all selected body parts over time.
        If ego is provided, it only returns distances to a specified bodypart"""

        if verbose:
            print("Computing distances...")

        nodes = self.distances
        if nodes == "all":
            nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]

        # BUGFIX: the original asserted on a list comprehension, which is
        # always truthy when non-empty; use all() so the check is effective.
        assert all(
            i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
        ), "Nodes should correspond to existent bodyparts"

        # Columns 2: of the scales array hold the pixel/mm conversion terms.
        scales = self.scales[:, 2:]

        distance_dict = {
            key: deepof.utils.bpart_distance(
                tab,
                scales[i, 1],
                scales[i, 0],
            )
            for i, (key, tab) in enumerate(tab_dict.items())
        }

        # Keep only distances whose both endpoints are among the requested nodes.
        for key in distance_dict.keys():
            distance_dict[key] = distance_dict[key].loc[
                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
            ]

        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
                    :, [dist for dist in val.columns if self.ego in dist]
                ]

        return distance_dict

    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """
        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        # 3-cliques of the connectivity graph define the bodypart trios.
        cliques = deepof.utils.nx.enumerate_all_cliques(self.connectivity)
        cliques = [i for i in cliques if len(i) == 3]

        angle_dict = {}
        for key, tab in tab_dict.items():

            dats = []
            for clique in cliques:
                dat = pd.DataFrame(
                    deepof.utils.angle_trio(
                        np.array(tab[clique]).reshape([3, tab.shape[0], 2])
                    )
                ).T

                # One column per vertex order (angle at each trio vertex).
                orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]
                dat.columns = [tuple(clique[i] for i in order) for order in orders]

                dats.append(dat)

            dats = pd.concat(dats, axis=1)

            angle_dict[key] = dats

        return angle_dict

    def run(self, verbose: bool = True) -> Coordinates:
        """Generates a dataset using all the options specified during initialization"""

        tables, quality = self.load_tables(verbose)
        distances = None
        angles = None

        if self.distances:
            distances = self.get_distances(tables, verbose)

        if self.angles:
            angles = self.get_angles(tables, verbose)

        if verbose:
            print("Done!")

        return coordinates(
            angles=angles,
            animal_ids=self.animal_ids,
            arena=self.arena,
            arena_dims=self.arena_dims,
            distances=distances,
            exp_conditions=self.exp_conditions,
            path=self.path,
            quality=quality,
            scales=self.scales,
            tables=tables,
            videos=self.videos,
        )

    @subset_condition.setter
    def subset_condition(self, value):
        self._subset_condition = value

    @distances.setter
    def distances(self, value):
        self._distances = value

    @ego.setter
    def ego(self, value):
        self._ego = value

    @angles.setter
    def angles(self, value):
        self._angles = value

441
442

class coordinates:
443
444
445
446
447
448
449
    """

    Class for storing the results of a ran project. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

450
    def __init__(
451
        self,
452
453
        arena: str,
        arena_dims: np.array,
454
        path: str,
455
456
457
458
        quality: dict,
        scales: np.array,
        tables: dict,
        videos: list,
459
        angles: dict = None,
460
        animal_ids: List = tuple([""]),
461
462
        distances: dict = None,
        exp_conditions: dict = None,
463
    ):
464
        self._animal_ids = animal_ids
465
466
        self._arena = arena
        self._arena_dims = arena_dims
467
        self._exp_conditions = exp_conditions
468
        self._path = path
469
470
471
472
473
474
        self._quality = quality
        self._scales = scales
        self._tables = tables
        self._videos = videos
        self.angles = angles
        self.distances = distances
475
476
477
478

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
479
                len(self._videos), len(set(self._exp_conditions.values()))
480
481
            )
        else:
482
            return "deepof analysis of {} videos".format(len(self._videos))
483

484
    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
        align_inplace: bool = False,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If false, the raw data is returned; if 'arena' (default), coordinates are
                centered in the pitch
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (bool): selects the body part to which later processes will align the frames with
                (see preprocess in table_dict documentation).
                - align_inplace (bool): Only valid if align is set. Aligns the vector that goes from the origin to
                the selected body part with the y axis, for all time points.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for initialising the weights of the
                clusters in the latent space, in a way that each cluster is related to a different annotation.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = deepof.utils.tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":

                # Shift every coordinate by half the detected arena diameter,
                # so the origin sits at the arena centre. The KeyError branch
                # handles tables already converted to polar (rho/phi) columns.
                for i, (key, value) in enumerate(tabs.items()):

                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif isinstance(center, str) and center != "arena":

            # Centre on a named body part and drop that part's own columns.
            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)
                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = deepof.utils.rolling_speed(tab, deriv=speed, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                # BUGFIX: the scraped source read ".asisinstance(...)"; the
                # pandas Index API is .astype.
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype("timedelta64[s]")

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns

                tab = tab[columns]
                tabs[key] = tab

                if align_inplace and polar is False:
                    index = tab.columns
                    tab = pd.DataFrame(
                        deepof.utils.align_trajectories(np.array(tab), mode="all")
                    )
                    tab.columns = index
                    tabs[key] = tab

        if propagate_labels:
            for key, tab in tabs.items():
                tab.loc[:, "pheno"] = self._exp_conditions[key]

        if propagate_annotations:
            annotations = list(propagate_annotations.values())[0].columns

            for key, tab in tabs.items():
                for ann in annotations:
                    tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
            propagate_labels=propagate_labels,
            propagate_annotations=propagate_annotations,
        )
633
    def get_distances(
        self,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the distances between body parts per animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for initialising the weights of the
                clusters in the latent space, in a way that each cluster is related to a different annotation.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information

            Raises:
                ValueError: if distances were not computed by the project.
        """

        tabs = deepof.utils.deepcopy(self.distances)

        if self.distances is not None:

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="dists")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    # BUGFIX: the scraped source read ".asisinstance(...)"; the
                    # pandas Index API is .astype.
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab.loc[:, "pheno"] = self._exp_conditions[key]

            if propagate_annotations:
                annotations = list(propagate_annotations.values())[0].columns

                for key, tab in tabs.items():
                    for ann in annotations:
                        tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

            return table_dict(
                tabs,
                propagate_labels=propagate_labels,
                propagate_annotations=propagate_annotations,
                typ="dists",
            )

        raise ValueError(
            "Distances not computed. Read the documentation for more details"
        )  # pragma: no cover
694
    def get_angles(
        self,
        degrees: bool = False,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts per animal as values.

            Parameters:
                - degrees (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh::mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for initialising the weights of the
                clusters in the latent space, in a way that each cluster is related to a different annotation.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information

            Raises:
                ValueError: if angles were not computed by the project.
        """

        tabs = deepof.utils.deepcopy(self.angles)

        if self.angles is not None:

            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="angles")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    # BUGFIX: the scraped source read ".asisinstance(...)"; the
                    # pandas Index API is .astype.
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    # Consistency: use the same .loc accessor as the sibling
                    # getters (original used bare tab["pheno"]).
                    tab.loc[:, "pheno"] = self._exp_conditions[key]

            if propagate_annotations:
                annotations = list(propagate_annotations.values())[0].columns

                for key, tab in tabs.items():
                    for ann in annotations:
                        tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

            return table_dict(
                tabs,
                propagate_labels=propagate_labels,
                propagate_annotations=propagate_annotations,
                typ="angles",
            )

        raise ValueError(
            "Angles not computed. Read the documentation for more details"
        )  # pragma: no cover
761
    def get_videos(self, play: bool = False):
        """Returns the videos associated with the dataset as a list.

            Parameters:
                - play (bool): if True, plays the videos (not implemented yet).
        """

        if play:  # pragma: no cover
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
        """Stored dictionary mapping each experimental subject to its condition."""
        return self._exp_conditions

773
    def get_quality(self):
        """Dictionary with the per-video tagging quality, as reported by DLC."""
        return self._quality

    @property
    def get_arenas(self):
        """All available arena information: arena type, dimensions and scales."""
        arena_info = (self._arena, self._arena_dims, self._scales)
        return arena_info

784
    def rule_based_annotation(
        self,
        params: Dict = None,
        video_output: bool = False,
        frame_limit: int = np.inf,
        debug: bool = False,
    ) -> Table_dict:
        """Annotates coordinates using a simple rule-based pipeline.

            Parameters:
                - params (Dict): overrides for the default rule-based tagging parameters;
                passed through to deepof.pose_utils. Defaults to an empty dict.
                - video_output (bool / list / str): if a list of experiment names or the
                string "all" is provided, annotated videos are rendered for those experiments.
                - frame_limit (int): maximum number of frames to render per output video.
                - debug (bool): if True, debug information is rendered in the output videos.

            Returns:
                tag_dict (Table_dict): table_dict object with one rule-based annotation
                table per experiment.

            Raises:
                AttributeError: if video_output is neither 'all' nor a list of video names.
        """

        # Use a None sentinel instead of a mutable default argument, so the
        # default dict is never shared across calls
        params = {} if params is None else params

        tag_dict = {}

        # noinspection PyTypeChecker
        coords = self.get_coords(center=False)
        dists = self.get_distances()
        speeds = self.get_coords(speed=1)

        for key in tqdm(self._tables.keys()):

            # Each experiment key is a prefix of its DLC-generated video filename
            video = [vid for vid in self._videos if key + "DLC" in vid][0]
            tag_dict[key] = deepof.pose_utils.rule_based_tagging(
                list(self._tables.keys()),
                self._videos,
                self,
                coords,
                dists,
                speeds,
                self._videos.index(video),
                arena_type=self._arena,
                recog_limit=1,
                path=os.path.join(self._path, "Videos"),
                params=params,
            )

        if video_output:  # pragma: no cover

            def output_video(idx):
                """Outputs a single annotated video. Enclosed in a function to enable parallelization"""

                deepof.pose_utils.rule_based_video(
                    self,
                    list(self._tables.keys()),
                    self._videos,
                    list(self._tables.keys()).index(idx),
                    tag_dict[idx],
                    debug=debug,
                    frame_limit=frame_limit,
                    recog_limit=1,
                    path=os.path.join(self._path, "Videos"),
                    params=params,
                )
                pbar.update(1)

            if isinstance(video_output, list):
                vid_idxs = video_output
            elif video_output == "all":
                vid_idxs = list(self._tables.keys())
            else:
                raise AttributeError(
                    "Video output must be either 'all' or a list with the names of the videos to render"
                )

            # Render videos in parallel threads, using half the available cores
            njobs = cpu_count() // 2
            pbar = tqdm(total=len(vid_idxs))
            with parallel_backend("threading", n_jobs=njobs):
                Parallel()(delayed(output_video)(key) for key in vid_idxs)
            pbar.close()

        return table_dict(
            tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
        )
854

855
856
    @staticmethod
    def deep_unsupervised_embedding(
        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        batch_size: int = 256,
        encoding_size: int = 4,
        epochs: int = 35,
        hparams: dict = None,
        kl_warmup: int = 0,
        log_history: bool = True,
        log_hparams: bool = False,
        loss: str = "ELBO",
        mmd_warmup: int = 0,
        montecarlo_kl: int = 10,
        n_components: int = 25,
        output_path: str = ".",
        phenotype_class: float = 0,
        predictor: float = 0,
        pretrained: str = False,
        save_checkpoints: bool = False,
        save_weights: bool = True,
        variational: bool = True,
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
        entropy_samples: int = 10000,
        entropy_knn: int = 100,
    ) -> Tuple:
        """
        Annotates coordinates using an unsupervised autoencoder.
        Full implementation in deepof.train_utils.deep_unsupervised_embedding

        Parameters:
            - preprocessed_object (Tuple[np.ndarray]): tuple containing a preprocessed object (X_train,
            y_train, X_test, y_test)
            - batch_size (int): training batch size
            - encoding_size (int): number of dimensions in the latent space of the autoencoder
            - epochs (int): epochs during which to train the models
            - hparams (dict): dictionary to change architecture hyperparameters of the autoencoders
            (see documentation for details)
            - kl_warmup (int): number of epochs over which to increase KL weight linearly
            (default is number of epochs // 4)
            - log_history (bool): if True, training history is logged to disk
            - log_hparams (bool): if True, hyperparameters are logged to disk
            - loss (str): Loss function to use. Currently, 'ELBO', 'MMD' and 'ELBO+MMD' are supported.
            - mmd_warmup (int): number of epochs over which to increase MMD weight linearly
            (default is number of epochs // 4)
            - montecarlo_kl (int): Number of Montecarlo samples used to estimate the KL between latent space and prior
            - n_components (int): Number of components of the Gaussian Mixture in the latent space
            - output_path (str): Path where to save the training loggings
            - phenotype_class (float): weight assigned to phenotype classification. If > 0,
            a classification neural network is appended to the latent space,
            aiming to enforce structure from a set of labels in the encoding.
            - predictor (float): weight assigned to a predictor branch. If > 0, a regression neural network
            is appended to the latent space,
            aiming to predict what happens immediately next in the sequence, which can help with regularization.
            - pretrained (bool): If True, a pretrained set of weights is expected.
            - save_checkpoints (bool): if True, training checkpoints are saved to disk. Useful for debugging,
            but can make training significantly slower
            - save_weights (bool): if True, the final trained weights are saved to disk
            - variational (bool): If True (default) a variational autoencoder is used. If False,
            a simple autoencoder is used for dimensionality reduction
            - reg_cat_clusters (bool): if True, cluster membership entropy is added to the loss as a regularizer
            - reg_cluster_variance (bool): if True, cluster variance is regularized across components
            - entropy_samples (int): number of samples used when estimating latent-space entropy
            - entropy_knn (int): number of nearest neighbours used when estimating latent-space entropy

        Returns:
            - return_list (tuple): List containing all relevant trained models for unsupervised prediction.

        """

        # Thin wrapper: all keyword arguments are forwarded unchanged
        trained_models = deepof.train_utils.autoencoder_fitting(
            preprocessed_object=preprocessed_object,
            batch_size=batch_size,
            encoding_size=encoding_size,
            epochs=epochs,
            hparams=hparams,
            kl_warmup=kl_warmup,
            log_history=log_history,
            log_hparams=log_hparams,
            loss=loss,
            mmd_warmup=mmd_warmup,
            montecarlo_kl=montecarlo_kl,
            n_components=n_components,
            output_path=output_path,
            phenotype_class=phenotype_class,
            predictor=predictor,
            pretrained=pretrained,
            save_checkpoints=save_checkpoints,
            save_weights=save_weights,
            variational=variational,
            reg_cat_clusters=reg_cat_clusters,
            reg_cluster_variance=reg_cluster_variance,
            entropy_samples=entropy_samples,
            entropy_knn=entropy_knn,
        )

        # returns a list of trained tensorflow models
        return trained_models
946

947
948

class table_dict(dict):
949
950
951
952
953
954
955
956
957
    """

    Main class for storing a single dataset as a dictionary with individuals as keys and pandas.DataFrames as values.
    Includes methods for generating training and testing datasets for the autoencoders.

    """

    def __init__(
        self,
958
        tabs: Dict,
959
960
961
962
963
        typ: str,
        arena: str = None,
        arena_dims: np.array = None,
        center: str = None,
        polar: bool = None,
964
        propagate_labels: bool = False,
965
        propagate_annotations: Dict = False,
966