# @author lucasmiranda42
# encoding: utf-8
# module deepof
"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations for getting the data into the desired shape.
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.

"""

from collections import defaultdict
from joblib import delayed, Parallel, parallel_backend
from typing import Dict, List, Tuple, Union
from multiprocessing import cpu_count
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm
import deepof.models
import deepof.pose_utils
import deepof.utils
import deepof.visuals
import deepof.train_utils
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import warnings

# Remove excessive logging from tensorflow
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# DEFINE CUSTOM ANNOTATED TYPES #

Coordinates = deepof.utils.NewType("Coordinates", deepof.utils.Any)
Table_dict = deepof.utils.NewType("Table_dict", deepof.utils.Any)


# CLASSES FOR PREPROCESSING AND DATA WRANGLING

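# A minimal, illustrative workflow sketch (not part of the library itself). The folder
# name and arena diameter below are hypothetical placeholders — adapt them to your own
# DLC project, which is expected to contain "Videos" and "Tables" subfolders:
#
#   my_project = project(
#       path=os.path.join(".", "my_dlc_project"),
#       arena="circular",
#       arena_dims=(380,),          # hypothetical arena diameter used to scale distances
#       video_format=".mp4",
#       table_format=".h5",
#   )
#   my_coords = my_project.run(verbose=True)   # returns a coordinates object
#   position_tables = my_coords.get_coords()   # table_dict with one DataFrame per experiment
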
class project:
    """

    Class for loading and preprocessing DLC data of individual and multiple animals. All main computations are called
    here.

    """

    def __init__(
        self,
        animal_ids: List = tuple([""]),
        arena: str = "circular",
        arena_dims: tuple = (1,),
        enable_iterative_imputation: bool = None,
        exclude_bodyparts: List = tuple([""]),
        exp_conditions: dict = None,
        interpolate_outliers: bool = True,
        interpolation_limit: int = 5,
        interpolation_std: int = 5,
        likelihood_tol: float = 0.25,
        model: str = "mouse_topview",
        path: str = deepof.utils.os.path.join("."),
        smooth_alpha: float = 0.99,
        table_format: str = "autodetect",
        video_format: str = ".mp4",
    ):

        self.path = path
        self.video_path = os.path.join(self.path, "Videos")
        self.table_path = os.path.join(self.path, "Tables")

        self.table_format = table_format
        if self.table_format == "autodetect":
            ex = [i for i in os.listdir(self.table_path) if not i.startswith(".")][0]
            if ".h5" in ex:
                self.table_format = ".h5"
            elif ".csv" in ex:
                self.table_format = ".csv"

        self.videos = sorted(
            [
                vid
                for vid in deepof.utils.os.listdir(self.video_path)
                if vid.endswith(video_format) and not vid.startswith(".")
            ]
        )
        self.tables = sorted(
            [
                tab
                for tab in deepof.utils.os.listdir(self.table_path)
                if tab.endswith(self.table_format) and not tab.startswith(".")
            ]
        )
        self.angles = True
        self.animal_ids = animal_ids
        self.arena = arena
        self.arena_dims = arena_dims
        self.distances = "all"
        self.ego = False
        self.exp_conditions = exp_conditions
        self.interpolate_outliers = interpolate_outliers
        self.interpolation_limit = interpolation_limit
        self.interpolation_std = interpolation_std
        self.likelihood_tolerance = likelihood_tol
        self.scales = self.get_scale
        self.smooth_alpha = smooth_alpha
        self.subset_condition = None
        self.video_format = video_format

        if enable_iterative_imputation is None:
            self.enable_iterative_imputation = self.animal_ids == tuple([""])
        else:
            self.enable_iterative_imputation = enable_iterative_imputation

        model_dict = {
            "mouse_topview": deepof.utils.connect_mouse_topview(animal_ids[0])
        }
        self.connectivity = model_dict[model]

        self.exclude_bodyparts = exclude_bodyparts
        if self.exclude_bodyparts != tuple([""]):
            for bp in exclude_bodyparts:
                self.connectivity.remove_node(bp)

    def __str__(self):
        if self.exp_conditions:
            return "deepof analysis of {} videos across {} conditions".format(
                len(self.videos), len(set(self.exp_conditions.values()))
            )
        else:
            return "deepof analysis of {} videos".format(len(self.videos))

    @property
    def subset_condition(self):
        """Sets a subset condition for the videos to load. If set,
        only the videos with the included pattern will be loaded"""
        return self._subset_condition

    @property
    def distances(self):
        """List. If not 'all', sets the body parts among which the
        distances will be computed"""
        return self._distances

    @property
    def ego(self):
        """String, name of a body part. If set, computes only the distances
        between the specified body part and the rest"""
        return self._ego

    @property
    def angles(self):
        """Bool. Toggles angle computation. True by default. If turned off,
        enhances performance for big datasets"""
        return self._angles

    @property
    def get_scale(self) -> np.array:
        """Returns the arena as recognised from the videos"""

        if self.arena in ["circular"]:

            scales = []
            for vid_index, _ in enumerate(self.videos):

                ellipse = deepof.utils.recognize_arena(
                    self.videos,
                    vid_index,
                    path=self.video_path,
                    arena_type=self.arena,
                )[0]

                scales.append(
                    list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1]]) * 2)
                    + list(self.arena_dims)
                )

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales)

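    # Note on the layout of the scales array returned above (an inference from how it is
    # used elsewhere in this module, not an authoritative spec): each row concatenates the
    # detected arena parameters, multiplied by 2, with the user-supplied arena_dims.
    # get_coords() halves the first two entries and subtracts them from the coordinates to
    # centre the data, while get_distances() reads columns 2 onwards (detected arena size
    # in pixels followed by arena_dims), apparently to rescale pixel distances into
    # arena_dims units.
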
    def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
        """Loads videos and tables into dictionaries"""

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        tab_dict = {}

        if self.table_format == ".h5":

            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_hdf(
                    deepof.utils.os.path.join(self.table_path, tab), dtype=float
                )
                for tab in self.tables
            }
        elif self.table_format == ".csv":

            tab_dict = {
                deepof.utils.re.findall("(.*)DLC", tab)[0]: pd.read_csv(
                    deepof.utils.os.path.join(self.table_path, tab),
                    header=[0, 1, 2],
                    index_col=0,
                    dtype=float,
                )
                for tab in self.tables
            }

        lik_dict = defaultdict()

        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik = value.xs("likelihood", level="coords", axis=1, drop_level=True)

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik.droplevel("scorer", axis=1)

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    deepof.utils.smooth_mult_trajectory(
                        np.array(tab), alpha=self.smooth_alpha
                    )
                )
                smooth.columns = cols
                tab_dict[key] = smooth.iloc[1:, :].reset_index(drop=True)

        for key, tab in tab_dict.items():
            tab_dict[key] = tab.loc[:, tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in tab_dict.items():
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                tab_dict[key] = tab

        if self.exclude_bodyparts != tuple([""]):

            for k, value in tab_dict.items():
                temp = value.drop(self.exclude_bodyparts, axis=1, level="bodyparts")
                temp.sort_index(axis=1, inplace=True)
                temp.columns = pd.MultiIndex.from_product(
                    [sorted(list(set([i[j] for i in temp.columns]))) for j in range(2)]
                )
                tab_dict[k] = temp.sort_index(axis=1)

        if self.interpolate_outliers:

            if verbose:
                print("Interpolating outliers...")

            for k, value in tab_dict.items():
                tab_dict[k] = deepof.utils.interpolate_outliers(
                    value,
                    lik_dict[k],
                    likelihood_tolerance=self.likelihood_tolerance,
                    mode="or",
                    limit=self.interpolation_limit,
                    n_std=self.interpolation_std,
                )

        if self.enable_iterative_imputation:

            if verbose:
                print("Iterative imputation of occluded bodyparts...")

            for k, value in tab_dict.items():
                imputed = IterativeImputer(
                    max_iter=1000, skip_complete=True
                ).fit_transform(value)
                tab_dict[k] = pd.DataFrame(
                    imputed, index=value.index, columns=value.columns
                )

        return tab_dict, lik_dict

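    # Illustrative shape of the objects returned above by load_tables (the key and the body
    # part names are hypothetical): tab_dict is keyed by experiment name (the part of each
    # file name before "DLC"), and each value is a DataFrame with a (bodyparts, coords)
    # column MultiIndex, e.g.
    #
    #   tab_dict["Test 1"].columns
    #   MultiIndex([("Nose", "x"), ("Nose", "y"), ("Tail_base", "x"), ("Tail_base", "y")])
    #
    # while lik_dict stores the corresponding DLC likelihoods, one column per body part.
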
    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """Computes the distances between all selected body parts over time.
        If ego is provided, it only returns distances to the specified bodypart"""

        if verbose:
            print("Computing distances...")

        nodes = self.distances
        if nodes == "all":
            nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]

        assert all(
            i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
        ), "Nodes should correspond to existing bodyparts"

        scales = self.scales[:, 2:]

        distance_dict = {
            key: deepof.utils.bpart_distance(
                tab,
                scales[i, 1],
                scales[i, 0],
            )
            for i, (key, tab) in enumerate(tab_dict.items())
        }

        for key in distance_dict.keys():
            distance_dict[key] = distance_dict[key].loc[
                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
            ]

        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
                    :, [dist for dist in val.columns if self.ego in dist]
                ]

        return distance_dict

    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """

        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        cliques = deepof.utils.nx.enumerate_all_cliques(self.connectivity)
        cliques = [i for i in cliques if len(i) == 3]

        angle_dict = {}
        for key, tab in tab_dict.items():

            dats = []
            for clique in cliques:
                dat = pd.DataFrame(
                    deepof.utils.angle_trio(
                        np.array(tab[clique]).reshape([3, tab.shape[0], 2])
                    )
                ).T

                orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]
                dat.columns = [tuple(clique[i] for i in order) for order in orders]

                dats.append(dat)

            dats = pd.concat(dats, axis=1)

            angle_dict[key] = dats

        return angle_dict

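    # Each 3-clique of connected body parts contributes three columns to the angle tables,
    # one per ordering of the trio — presumably one angle per vertex, although the exact
    # convention is defined in deepof.utils.angle_trio. With a hypothetical clique
    # ("Nose", "Left_ear", "Right_ear"), the resulting columns would be the tuples
    # ("Nose", "Left_ear", "Right_ear"), ("Nose", "Right_ear", "Left_ear") and
    # ("Left_ear", "Nose", "Right_ear").
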
    def run(self, verbose: bool = True) -> Coordinates:
        """Generates a dataset using all the options specified during initialization"""

        tables, quality = self.load_tables(verbose)
        distances = None
        angles = None

        if self.distances:
            distances = self.get_distances(tables, verbose)

        if self.angles:
            angles = self.get_angles(tables, verbose)

        if verbose:
            print("Done!")

        return coordinates(
            angles=angles,
            animal_ids=self.animal_ids,
            arena=self.arena,
            arena_dims=self.arena_dims,
            distances=distances,
            exp_conditions=self.exp_conditions,
            path=self.path,
            quality=quality,
            scales=self.scales,
            tables=tables,
            videos=self.videos,
        )

    @subset_condition.setter
    def subset_condition(self, value):
        self._subset_condition = value

    @distances.setter
    def distances(self, value):
        self._distances = value

    @ego.setter
    def ego(self, value):
        self._ego = value

    @angles.setter
    def angles(self, value):
        self._angles = value


class coordinates:
    """

    Class for storing the results of running a project. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

    def __init__(
        self,
        arena: str,
        arena_dims: np.array,
        path: str,
        quality: dict,
        scales: np.array,
        tables: dict,
        videos: list,
        angles: dict = None,
        animal_ids: List = tuple([""]),
        distances: dict = None,
        exp_conditions: dict = None,
    ):
        self._animal_ids = animal_ids
        self._arena = arena
        self._arena_dims = arena_dims
        self._exp_conditions = exp_conditions
        self._path = path
        self._quality = quality
        self._scales = scales
        self._tables = tables
        self._videos = videos
        self.angles = angles
        self.distances = distances

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
                len(self._videos), len(set(self._exp_conditions.values()))
            )
        else:
            return "Coordinates of {} videos".format(len(self._videos))

    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
        align_inplace: bool = False,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If false, the raw data is returned; if 'arena' (default), coordinates are
                centered in the arena
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (bool): selects the body part to which later processes will align the frames
                (see preprocess in table_dict documentation).
                - align_inplace (bool): Only valid if align is set. Aligns the vector that goes from the origin to
                the selected body part with the y axis, for all time points.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for initialising the weights of the
                clusters in the latent space, in a way that each cluster is related to a different annotation.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = deepof.utils.tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":

                for i, (key, value) in enumerate(tabs.items()):

                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif type(center) == str and center != "arena":

            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)
                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = deepof.utils.rolling_speed(tab, deriv=speed, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype("timedelta64[s]")

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns

                tab = tab[columns]
                tabs[key] = tab

                if align_inplace and polar is False:
                    index = tab.columns
                    tab = pd.DataFrame(
                        deepof.utils.align_trajectories(np.array(tab), mode="all")
                    )
                    tab.columns = index
                    tabs[key] = tab

        if propagate_labels:
            for key, tab in tabs.items():
                tab.loc[:, "pheno"] = self._exp_conditions[key]

        if propagate_annotations:
            annotations = list(propagate_annotations.values())[0].columns

            for key, tab in tabs.items():
                for ann in annotations:
                    tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
            propagate_labels=propagate_labels,
            propagate_annotations=propagate_annotations,
        )

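    # Illustrative calls (a sketch only; my_coords stands for a coordinates object and
    # "Nose" / "Spine_1" are hypothetical body part names present in the tracking tables):
    #
    #   raw = my_coords.get_coords(center=False)           # raw, untouched coordinates
    #   egocentric = my_coords.get_coords(
    #       center="Nose", align="Spine_1", align_inplace=True
    #   )                                                  # centered and rotationally aligned
    #   speeds = my_coords.get_coords(speed=1)             # first derivative per body part
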
    def get_distances(
        self,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the distances between body parts of each animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for initialising the weights of the
                clusters in the latent space, in a way that each cluster is related to a different annotation.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self.distances)

        if self.distances is not None:

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="dists")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab.loc[:, "pheno"] = self._exp_conditions[key]

            if propagate_annotations:
                annotations = list(propagate_annotations.values())[0].columns

                for key, tab in tabs.items():
                    for ann in annotations:
                        tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

            return table_dict(
                tabs,
                propagate_labels=propagate_labels,
                propagate_annotations=propagate_annotations,
                typ="dists",
            )

        raise ValueError(
            "Distances not computed. Read the documentation for more details"
        )  # pragma: no cover

    def get_angles(
        self,
        degrees: bool = False,
        speed: int = 0,
        length: str = None,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts of each animal as values.

            Parameters:
                - degrees (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - propagate_labels (bool): If True, adds an extra feature for each video containing its phenotypic label
                - propagate_annotations (Dict): if a dictionary is provided, rule based annotations
                are propagated through the training dataset. This can be used for initialising the weights of the
                clusters in the latent space, in a way that each cluster is related to a different annotation.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepof.utils.deepcopy(self.angles)

        if self.angles is not None:

            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
                for key, tab in tabs.items():
                    vel = deepof.utils.rolling_speed(tab, deriv=speed + 1, typ="angles")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype("timedelta64[s]")

            if propagate_labels:
                for key, tab in tabs.items():
                    tab["pheno"] = self._exp_conditions[key]

            if propagate_annotations:
                annotations = list(propagate_annotations.values())[0].columns

                for key, tab in tabs.items():
                    for ann in annotations:
                        tab.loc[:, ann] = propagate_annotations[key].loc[:, ann]

            return table_dict(
                tabs,
                propagate_labels=propagate_labels,
                propagate_annotations=propagate_annotations,
                typ="angles",
            )

        raise ValueError(
            "Angles not computed. Read the documentation for more details"
        )  # pragma: no cover

    def get_videos(self, play: bool = False):
        """Returns the videos associated with the dataset as a list."""

        if play:  # pragma: no cover
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
        """Returns the stored dictionary with experimental conditions per subject"""

        return self._exp_conditions

    def get_quality(self):
        """Retrieves a dictionary with the tagging quality per video, as reported by DLC"""

        return self._quality

    @property
    def get_arenas(self):
        """Retrieves all available information associated with the arena"""

        return self._arena, self._arena_dims, self._scales

    # noinspection PyDefaultArgument
    def rule_based_annotation(
        self,
        params: Dict = {},
        video_output: bool = False,
        frame_limit: int = np.inf,
        debug: bool = False,
    ) -> Table_dict:
        """Annotates coordinates using a simple rule-based pipeline"""

        tag_dict = {}
        # noinspection PyTypeChecker
        coords = self.get_coords(center=False)
        dists = self.get_distances()
        speeds = self.get_coords(speed=1)

        for key in tqdm(self._tables.keys()):

            video = [vid for vid in self._videos if key + "DLC" in vid][0]
            tag_dict[key] = deepof.pose_utils.rule_based_tagging(
                list(self._tables.keys()),
                self._videos,
                self,
                coords,
                dists,
                speeds,
                self._videos.index(video),
                arena_type=self._arena,
                recog_limit=1,
                path=os.path.join(self._path, "Videos"),
                params=params,
            )

        if video_output:  # pragma: no cover

            def output_video(idx):
                """Outputs a single annotated video. Enclosed in a function to enable parallelization"""

                deepof.pose_utils.rule_based_video(
                    self,
                    list(self._tables.keys()),
                    self._videos,
                    list(self._tables.keys()).index(idx),
                    tag_dict[idx],
                    debug=debug,
                    frame_limit=frame_limit,
                    recog_limit=1,
                    path=os.path.join(self._path, "Videos"),
                    params=params,
                )
                pbar.update(1)

            if type(video_output) == list:
                vid_idxs = video_output
            elif video_output == "all":
                vid_idxs = list(self._tables.keys())
            else:
                raise AttributeError(
                    "Video output must be either 'all' or a list with the names of the videos to render"
                )

            njobs = cpu_count() // 2
            pbar = tqdm(total=len(vid_idxs))
            with parallel_backend("threading", n_jobs=njobs):
                Parallel()(delayed(output_video)(key) for key in vid_idxs)
            pbar.close()

        return table_dict(
            tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
        )

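    # Illustrative call (a sketch only; "Test 1" and "Test 2" are hypothetical experiment
    # names, and rendering annotated videos is optional):
    #
    #   rule_based_tags = my_coords.rule_based_annotation(
    #       video_output=["Test 1", "Test 2"],  # or "all" to render every video
    #       frame_limit=500,
    #   )
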
    @staticmethod
    def deep_unsupervised_embedding(
        preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        batch_size: int = 256,
        encoding_size: int = 4,
        epochs: int = 35,
        hparams: dict = None,
        kl_warmup: int = 0,
        log_history: bool = True,
        log_hparams: bool = False,
        loss: str = "ELBO",
        mmd_warmup: int = 0,
        montecarlo_kl: int = 10,
        n_components: int = 25,
        output_path: str = ".",
        phenotype_class: float = 0,
        predictor: float = 0,
        pretrained: str = False,
        save_checkpoints: bool = False,
        save_weights: bool = True,
        variational: bool = True,
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
        entropy_samples: int = 10000,
        entropy_min_n: int = 5,
    ) -> Tuple:
        """
        Annotates coordinates using an unsupervised autoencoder.
        Full implementation in deepof.train_utils.deep_unsupervised_embedding

        Parameters:
            - preprocessed_object (Tuple[np.ndarray]): tuple containing a preprocessed object (X_train,
            y_train, X_test, y_test)
            - encoding_size (int): number of dimensions in the latent space of the autoencoder
            - epochs (int): epochs during which to train the models
            - batch_size (int): training batch size
            - save_checkpoints (bool): if True, training checkpoints are saved to disk. Useful for debugging,
            but can make training significantly slower
            - hparams (dict): dictionary to change architecture hyperparameters of the autoencoders
            (see documentation for details)
            - kl_warmup (int): number of epochs over which to increase KL weight linearly
            (default is number of epochs // 4)
            - loss (str): Loss function to use. Currently, 'ELBO', 'MMD' and 'ELBO+MMD' are supported.
            - mmd_warmup (int): number of epochs over which to increase MMD weight linearly
            (default is number of epochs // 4)
            - montecarlo_kl (int): Number of Monte Carlo samples used to estimate the KL between latent space and prior
            - n_components (int): Number of components of the Gaussian Mixture in the latent space
            - output_path (str): Path where to save the training logs
            - phenotype_class (float): weight assigned to phenotype classification. If > 0,
            a classification neural network is appended to the latent space,
            aiming to enforce structure from a set of labels in the encoding.
            - predictor (float): weight assigned to a predictor branch. If > 0, a regression neural network
            is appended to the latent space,
            aiming to predict what happens immediately next in the sequence, which can help with regularization.
            - pretrained (bool): If True, a pretrained set of weights is expected.
            - variational (bool): If True (default) a variational autoencoder is used. If False,
            a simple autoencoder is used for dimensionality reduction

        Returns:
            - return_list (tuple): List containing all relevant trained models for unsupervised prediction.

        """

        trained_models = deepof.train_utils.autoencoder_fitting(
            preprocessed_object=preprocessed_object,
            batch_size=batch_size,
            encoding_size=encoding_size,
            epochs=epochs,
            hparams=hparams,
            kl_warmup=kl_warmup,
            log_history=log_history,
            log_hparams=log_hparams,
            loss=loss,
            mmd_warmup=mmd_warmup,
            montecarlo_kl=montecarlo_kl,
            n_components=n_components,
            output_path=output_path,
            phenotype_class=phenotype_class,
            predictor=predictor,
            pretrained=pretrained,
            save_checkpoints=save_checkpoints,
            save_weights=save_weights,
            variational=variational,
            reg_cat_clusters=reg_cat_clusters,
            reg_cluster_variance=reg_cluster_variance,
            entropy_samples=entropy_samples,
            entropy_min_n=entropy_min_n,
        )

        # returns a list of trained tensorflow models
        return trained_models

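    # Illustrative call (a sketch only — the preprocessed tuple is assumed to come from the
    # table_dict preprocessing utilities, and the hyperparameter values below are arbitrary):
    #
    #   trained = my_coords.deep_unsupervised_embedding(
    #       preprocessed_object=(X_train, y_train, X_test, y_test),
    #       encoding_size=6,
    #       n_components=10,
    #       loss="ELBO",
    #       epochs=50,
    #   )
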

class table_dict(dict):
    """

    Main class for storing a single dataset as a dictionary with individuals as keys and pandas.DataFrames as values.
    Includes methods for generating training and testing datasets for the autoencoders.

    """

    def __init__(
        self,
        tabs: Dict,
        typ: str,
        arena: str = None,
        arena_dims: np.array = None,
        center: str = None,
        polar: bool = None,
        propagate_labels: bool = False,
        propagate_annotations: Dict = False,
    ):
        super().__init__(tabs)
        self._type = typ
        self._center = center
        self._polar = polar
        self._arena = arena
        self._arena_dims = arena_dims