# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

Data structures for preprocessing and wrangling of DLC output data.

- project: initial structure for specifying the characteristics of the project.
- coordinates: result of running the project. In charge of calling all relevant
computations for getting the data into the desired shape
- table_dict: python dict subclass for storing experimental instances as pandas.DataFrames.
Contains methods for generating training and test sets ready for model training.

"""

import warnings
from collections import defaultdict

from deepof.utils import *
from deepof.visuals import *
from pandas_profiling import ProfileReport
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# DEFINE CUSTOM ANNOTATED TYPES #

Coordinates = NewType("Coordinates", Any)
Table_dict = NewType("Table_dict", Any)

# CLASSES FOR PREPROCESSING AND DATA WRANGLING
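
# A typical end-to-end pipeline with this module (a minimal sketch; the path,
# arena size and parameter values below are hypothetical):
#
#   my_project = project(path="./my_experiment", arena="circular", arena_dims=(380,))
#   my_coords = my_project.run(verbose=True)         # coordinates object
#   coords = my_coords.get_coords(center="arena")    # table_dict of positions
#   X_train, X_test = coords.preprocess(window_size=11, test_videos=1)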


class project:
    """

    Class for loading and preprocessing DLC data of individual and multiple animals. All main computations are called
    here.

    """

    def __init__(
        self,
        video_format: str = ".mp4",
        table_format: str = ".h5",
        path: str = ".",
        exp_conditions: dict = None,
        subset_condition: list = None,
        arena: str = "circular",
        smooth_alpha: float = 0.1,
        arena_dims: tuple = (1,),
        distances: str = "All",
        ego: str = False,
        angles: bool = True,
        model: str = "mouse_topview",
    ):

        self.path = path
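        # Videos and tracking tables are expected under fixed subfolders of `path`
        # (the file names below are hypothetical examples):
        #   <path>/Videos/ -> video files matching video_format (e.g. Test1_raw.mp4)
        #   <path>/Tables/ -> DLC tracking files matching table_format (e.g. Test1_dlc.h5)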
        self.video_path = self.path + "/Videos/"
        self.table_path = self.path + "/Tables/"
        self.videos = sorted(
            [vid for vid in os.listdir(self.video_path) if vid.endswith(video_format)]
        )
        self.tables = sorted(
            [tab for tab in os.listdir(self.table_path) if tab.endswith(table_format)]
        )
        self.exp_conditions = exp_conditions
        self.subset_condition = subset_condition
        self.table_format = table_format
        self.video_format = video_format
        self.arena = arena
        self.arena_dims = arena_dims
        self.smooth_alpha = smooth_alpha
        self.distances = distances
        self.ego = ego
        self.angles = angles
        self.scales = self.get_scale

        model_dict = {"mouse_topview": connect_mouse_topview()}
        self.connectivity = model_dict[model]

    def __str__(self):
        if self.exp_conditions:
            return "DLC analysis of {} videos across {} conditions".format(
                len(self.videos), len(self.exp_conditions)
            )
        else:
            return "DLC analysis of {} videos".format(len(self.videos))

    def load_tables(self, verbose: bool = False) -> Tuple:
        """Loads videos and tables into dictionaries"""

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        tab_dict = {}

        if self.table_format == ".h5":

            tab_dict = {
                re.findall("(.*?)_", tab)[0]: pd.read_hdf(
                    os.path.join(self.table_path, tab), dtype=float
                )
                for tab in self.tables
            }
        elif self.table_format == ".csv":

            for tab in self.tables:
                head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2)
                data = pd.read_csv(
                    os.path.join(self.table_path, tab),
                    skiprows=2,
                    index_col="coords",
                    dtype={"coords": int},
                ).drop("1", axis=1)
                data.columns = pd.MultiIndex.from_product(
                    [
                        [head.columns[2]],
                        set(list(head.iloc[0])[2:]),
                        ["x", "y", "likelihood"],
                    ],
                    names=["scorer", "bodyparts", "coords"],
                )
                tab_dict[re.findall("(.*?)_", tab)[0]] = data

        lik_dict = defaultdict()

        for key, value in tab_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik: pd.DataFrame = value.xs(
                "likelihood", level="coords", axis=1, drop_level=True
            )

            tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in tab_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    smooth_mult_trajectory(np.array(tab), alpha=self.smooth_alpha)
                )
                smooth.columns = cols
                tab_dict[key] = smooth

        for key, tab in tab_dict.items():
            tab_dict[key] = tab[tab.columns.levels[0][0]]

        if self.subset_condition:
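            # Keep only the body parts whose names start with subset_condition
            # (e.g. the prefix identifying a single animal); all other columns are dropped.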
            for key, value in tab_dict.items():
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                tab_dict[key] = tab

        return tab_dict, lik_dict

    @property
    def get_scale(self) -> np.array:
        """Returns the arena as recognised from the videos"""

        if self.arena in ["circular"]:

            scales = []
            for vid_index, _ in enumerate(self.videos):
                scales.append(
                    list(
                        recognize_arena(
                            self.videos,
                            vid_index,
                            path=self.video_path,
                            arena_type=self.arena,
                        )[0]
                        * 2
                    )
                    + list(self.arena_dims)
                )

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales)

    def get_distances(self, tab_dict: dict, verbose: bool = False) -> dict:
        """Computes the distances between all selected body parts over time.
           If ego is provided, it only returns distances to a specified bodypart"""

        if verbose:
            print("Computing distances...")

        nodes = self.distances
        if nodes == "All":
            nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]

        assert all(
            i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
        ), "Nodes should correspond to existing body parts"

        scales = self.scales[:, 2:]

        distance_dict = {
            key: bpart_distance(tab, scales[i, 1], scales[i, 0],)
            for i, (key, tab) in enumerate(tab_dict.items())
        }

        for key in distance_dict.keys():
            distance_dict[key] = distance_dict[key].loc[
                :, [np.all([i in nodes for i in j]) for j in distance_dict[key].columns]
            ]

        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
                    :, [dist for dist in val.columns if self.ego in dist]
                ]

        return distance_dict

    def get_angles(self, tab_dict: dict, verbose: bool = False) -> dict:
        """

        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video

        """

        if verbose:
            print("Computing angles...")

        cliques = nx.enumerate_all_cliques(self.connectivity)
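        # Keep only cliques of exactly three connected body parts; each trio
        # contributes three angle columns (one per vertex ordering defined below).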
        cliques = [i for i in cliques if len(i) == 3]

        angle_dict = {}
        for key, tab in tab_dict.items():

            dats = []
            for clique in cliques:
                dat = pd.DataFrame(
                    angle_trio(np.array(tab[clique]).reshape([3, tab.shape[0], 2]))
                ).T

                orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]
                dat.columns = [tuple(clique[i] for i in order) for order in orders]

                dats.append(dat)

            dats = pd.concat(dats, axis=1)

            angle_dict[key] = dats

        return angle_dict

    def run(self, verbose: bool = False) -> Coordinates:
        """Generates a dataset using all the options specified during initialization"""

        tables, quality = self.load_tables(verbose)
        distances = None
        angles = None

        if self.distances:
            distances = self.get_distances(tables, verbose)

        if self.angles:
            angles = self.get_angles(tables, verbose)

        if verbose:
            print("Done!")

        return coordinates(
            tables=tables,
            videos=self.videos,
            arena=self.arena,
            arena_dims=self.arena_dims,
            scales=self.scales,
            quality=quality,
            exp_conditions=self.exp_conditions,
            distances=distances,
            angles=angles,
        )


class coordinates:
    """

    Class for storing the results of an executed project. Methods are mostly setters and getters in charge of tidying up
    the generated tables. For internal usage only.

    """

    def __init__(
        self,
        tables: dict,
        videos: list,
        arena: str,
        arena_dims: np.array,
        scales: np.array,
        quality: dict,
        exp_conditions: dict = None,
        distances: dict = None,
        angles: dict = None,
    ):
        self._tables = tables
        self.distances = distances
        self.angles = angles
        self._videos = videos
        self._exp_conditions = exp_conditions
        self._arena = arena
        self._arena_dims = arena_dims
        self._scales = scales
        self._quality = quality

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
                len(self._videos), len(self._exp_conditions)
            )
        else:
            return "DLC analysis of {} videos".format(len(self._videos))

    def get_coords(
        self,
        center: str = "arena",
        polar: bool = False,
        speed: int = 0,
        length: str = None,
        align: bool = False,
    ) -> Table_dict:
        """
        Returns a table_dict object with the coordinates of each animal as values.

            Parameters:
                - center (str): name of the body part to which the positions will be centered.
                If False, the raw data is returned; if 'arena' (default), coordinates are
                centered in the arena
                - polar (bool): states whether the coordinates should be converted to polar values
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.
                - align (str): selects the body part to which later processes will align the frames
                (see preprocess in the table_dict documentation).

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepcopy(self._tables)

        if polar:
            for key, tab in tabs.items():
                tabs[key] = tab2polar(tab)

        if center == "arena":
            if self._arena == "circular":
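                # Shift the coordinates so that the centre of the detected arena becomes
                # the origin; the except branch handles tables already converted to
                # polar (rho / phi) coordinates.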
                for i, (key, value) in enumerate(tabs.items()):

                    try:
                        value.loc[:, (slice("coords"), ["x"])] = (
                            value.loc[:, (slice("coords"), ["x"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["y"])] = (
                            value.loc[:, (slice("coords"), ["y"])]
                            - self._scales[i][1] / 2
                        )
                    except KeyError:
                        value.loc[:, (slice("coords"), ["rho"])] = (
                            value.loc[:, (slice("coords"), ["rho"])]
                            - self._scales[i][0] / 2
                        )

                        value.loc[:, (slice("coords"), ["phi"])] = (
                            value.loc[:, (slice("coords"), ["phi"])]
                            - self._scales[i][1] / 2
                        )

        elif type(center) == str and center != "arena":

            for i, (key, value) in enumerate(tabs.items()):

                try:
                    value.loc[:, (slice("coords"), ["x"])] = value.loc[
                        :, (slice("coords"), ["x"])
                    ].subtract(value[center]["x"], axis=0)

                    value.loc[:, (slice("coords"), ["y"])] = value.loc[
                        :, (slice("coords"), ["y"])
                    ].subtract(value[center]["y"], axis=0)
                except KeyError:
                    value.loc[:, (slice("coords"), ["rho"])] = value.loc[
                        :, (slice("coords"), ["rho"])
                    ].subtract(value[center]["rho"], axis=0)

                    value.loc[:, (slice("coords"), ["phi"])] = value.loc[
                        :, (slice("coords"), ["phi"])
                    ].subtract(value[center]["phi"], axis=0)

                tabs[key] = value.loc[
                    :, [tab for tab in value.columns if tab[0] != center]
                ]

        if speed:
            for key, tab in tabs.items():
                vel = rolling_speed(tab, deriv=speed + 1, center=center)
                tabs[key] = vel

        if length:
            for key, tab in tabs.items():
                tabs[key].index = pd.timedelta_range(
                    "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                ).astype('timedelta64[s]')

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns
                tabs[key] = tab[columns]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
        )
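
    # A minimal usage sketch (the coordinates object and body part names below are
    # hypothetical):
    #
    #   position_dict = my_coords.get_coords(center="Nose", align="Spine_1")
    #   speed_dict = my_coords.get_coords(center="arena", speed=1)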

    def get_distances(self, speed: int = 0, length: str = None) -> Table_dict:
        """
        Returns a table_dict object with the distances between body parts of each animal as values.

            Parameters:
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepcopy(self.distances)

        if self.distances is not None:

            if speed:
                for key, tab in tabs.items():
                    vel = rolling_speed(tab, deriv=speed + 1, typ="dists")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype('timedelta64[s]')

            return table_dict(tabs, typ="dists")

        raise ValueError(
            "Distances not computed. Read the documentation for more details"
        )

    def get_angles(
        self, degrees: bool = False, speed: int = 0, length: str = None
    ) -> Table_dict:
        """
        Returns a table_dict object with the angles between body parts of each animal as values.

            Parameters:
                - degrees (bool): if True, returns the angles in degrees. Radians (default) are returned otherwise.
                - speed (int): states the derivative of the positions to report. Speed is returned if 1,
                acceleration if 2, jerk if 3, etc.
                - length (str): length of the video in a datetime compatible format (hh:mm:ss). If stated, the index
                of the stored dataframes will reflect the actual timing in the video.

            Returns:
                tab_dict (Table_dict): table_dict object containing all the computed information
        """

        tabs = deepcopy(self.angles)

        if self.angles is not None:
            if degrees:
                tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

            if speed:
                for key, tab in tabs.items():
                    vel = rolling_speed(tab, deriv=speed + 1, typ="angles")
                    tabs[key] = vel

            if length:
                for key, tab in tabs.items():
                    tabs[key].index = pd.timedelta_range(
                        "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
                    ).astype('timedelta64[s]')

            return table_dict(tabs, typ="angles")

        raise ValueError("Angles not computed. Read the documentation for more details")

    def get_videos(self, play: bool = False):
        """Returns the videos associated with the dataset as a list."""

        if play:
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
        """Returns the stored dictionary with experimental conditions per subject"""

        return self._exp_conditions

    def get_quality(self, report: bool = False):
        """Retrieves a dictionary with the tagging quality per video, as reported by DLC"""

        if report:
            profile = ProfileReport(
                self._quality[report],
                title="Quality Report, {}".format(report),
                html={"style": {"full_width": True}},
            )
            return profile
        return self._quality

    @property
    def get_arenas(self):
        """Retrieves all available information associated with the arena"""

        return self._arena, self._arena_dims, self._scales

    def rule_based_annotation(self):
        pass


class table_dict(dict):
    """

    Main class for storing a single dataset as a dictionary with individuals as keys and pandas.DataFrames as values.
    Includes methods for generating training and testing datasets for the autoencoders.

    """

    def __init__(
        self,
        tabs: Coordinates,
        typ: str,
        arena: str = None,
        arena_dims: np.array = None,
        center: str = None,
        polar: bool = None,
    ):
        super().__init__(tabs)
        self._type = typ
        self._center = center
        self._polar = polar
        self._arena = arena
        self._arena_dims = arena_dims

    def filter(self, keys: list) -> Table_dict:
        """Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
         for selecting data coming from videos of a specified condition."""

        assert np.all([k in self.keys() for k in keys]), "Invalid keys selected"

        return table_dict(
            {k: value for k, value in self.items() if k in keys}, self._type
        )

    # noinspection PyTypeChecker
    def plot_heatmaps(
        self, bodyparts: list, save: bool = False, i: int = 0
    ) -> plt.figure:
        """Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""

        if self._type != "coords" or self._polar:
            raise NotImplementedError(
                "Heatmaps only available for cartesian coordinates. Set polar to False in get_coords and try again"
            )

        if not self._center:
            warnings.warn("Heatmaps look better if you center the data")

        if self._arena == "circular":
            x_lim = (
                [-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
                if self._center
                else [0, self._arena_dims[i][0]]
            )
            y_lim = (
                [-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
                if self._center
                else [0, self._arena_dims[i][1]]
            )

            plot_heatmap(
                list(self.values())[i], bodyparts, xlim=x_lim, ylim=y_lim, save=save,
            )

    def get_training_set(self, test_videos: int = 0) -> Tuple[np.ndarray, np.ndarray]:
        """Generates training and test sets as numpy.array objects for model training"""

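        # Pad every video to the length of the longest one, so that all of them can
        # be stacked into a single 3D array before splitting off the test videos.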
        rmax = max([i.shape[0] for i in self.values()])
        raw_data = np.array(
            [np.pad(v, ((0, rmax - v.shape[0]), (0, 0))) for v in self.values()]
        )
        test_index = np.random.choice(range(len(raw_data)), test_videos, replace=False)

        X_test = []
        if test_videos > 0:
            X_test = np.concatenate(list(raw_data[test_index]))
            X_train = np.concatenate(list(np.delete(raw_data, test_index, axis=0)))

        else:
            X_train = np.concatenate(list(raw_data))

        return X_train, X_test

    # noinspection PyTypeChecker,PyGlobalUndefined
    def preprocess(
        self,
        window_size: int = 1,
        window_step: int = 1,
        scale: str = "standard",
        test_videos: int = 0,
        verbose: bool = False,
        conv_filter: bool = None,
        sigma: float = 1.0,
        shift: float = 0.0,
        shuffle: bool = False,
        align: str = False,
    ) -> np.ndarray:
        """

        Main method for preprocessing the loaded dataset. Capable of returning training
        and test sets ready for model training.

            Parameters:
                - window_size (int): Size of the sliding window to pass through the data to generate training instances
                - window_step (int): Step to take when sliding the window. If 1, a true sliding window is used;
                if equal to window_size, the data is split into non-overlapping chunks.
                - scale (str): Data scaling method. Must be one of 'standard' (default; recommended) and 'minmax'.
                - test_videos (int): Number of videos to use when generating the test set.
                If 0, no test set is generated (not recommended).
                - verbose (bool): prints job information if True
                - conv_filter (bool): must be one of None, 'gaussian'. If not None, convolves each instance
                with the specified kernel.
                - sigma (float): usable only if conv_filter is 'gaussian'. Standard deviation of the kernel to use.
                - shift (float): usable only if conv_filter is 'gaussian'. Shift from mean zero of the kernel to use.
                - shuffle (bool): Shuffles the data instances if True. In most use cases, it should be True for training
                and False for prediction.
                - align (bool): If "all", rotates all data instances to align the center -> align (selected before
                when calling get_coords) axis with the y-axis of the cartesian plane. If 'center', rotates all instances
                using the angle of the central frame of the sliding window. This way rotations of the animal are caught
                as well. It doesn't do anything if False.

            Returns:
                - X_train (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all training videos
                - X_test (np.ndarray): 3d dataset with shape (instances, sliding_window_size, features)
                generated from all test videos (if test_videos > 0)

        """

        global g
        X_train, X_test = self.get_training_set(test_videos)

        if scale:
            if verbose:
                print("Scaling data...")

            if scale == "standard":
                scaler = StandardScaler()
            elif scale == "minmax":
                scaler = MinMaxScaler()
            else:
                raise ValueError(
                    "Invalid scaler. Select one of standard, minmax or None"
                )

            X_train = scaler.fit_transform(
                X_train.reshape(-1, X_train.shape[-1])
            ).reshape(X_train.shape)

            if scale == "standard":
                assert np.allclose(np.mean(X_train), 0)
                assert np.allclose(np.std(X_train), 1)

            if test_videos:
                X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
                    X_test.shape
                )

            if verbose:
                print("Done!")

        if align == "all":
            X_train = align_trajectories(X_train, align)

        X_train = rolling_window(X_train, window_size, window_step)

        if align == "center":
            X_train = align_trajectories(X_train, align)

        if conv_filter == "gaussian":
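            # Build a Gaussian kernel spanning the window (optionally shifted from its
            # centre), normalise it to a maximum of 1, and weight every frame of each
            # sliding-window instance with it.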
            r = range(-int(window_size / 2), int(window_size / 2) + 1)
            r = [i - shift for i in r]
            g = np.array(
                [
                    1
                    / (sigma * np.sqrt(2 * np.pi))
                    * np.exp(-float(x) ** 2 / (2 * sigma ** 2))
                    for x in r
                ]
            )
            g /= np.max(g)
            X_train = X_train * g.reshape([1, window_size, 1])

        if test_videos:

            if align == "all":
                X_test = align_trajectories(X_test, align)

            X_test = rolling_window(X_test, window_size, window_step)

            if align == "center":
                X_test = align_trajectories(X_test, align)

            if conv_filter == "gaussian":
                X_test = X_test * g.reshape([1, window_size, 1])

            if shuffle:
                X_train = X_train[
                    np.random.choice(X_train.shape[0], X_train.shape[0], replace=False)
                ]
                X_test = X_test[
                    np.random.choice(X_test.shape[0], X_test.shape[0], replace=False)
                ]

            return X_train, X_test

        if shuffle:
            X_train = X_train[
                np.random.choice(X_train.shape[0], X_train.shape[0], replace=False)
            ]

        return X_train
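
    # A minimal usage sketch (hypothetical parameter values, applied to a table_dict
    # returned by coordinates.get_coords):
    #
    #   X_train, X_test = coords.preprocess(
    #       window_size=11, window_step=1, scale="standard", test_videos=1,
    #       conv_filter="gaussian", sigma=5.0, shuffle=True,
    #   )
    #
    # Both returned arrays have shape (instances, window_size, features), as described above.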

    def random_projection(
        self, n_components: int = None, sample: int = 1000
    ) -> Tuple[Any, Any]:
        """Returns a training set generated from the 2D original data (time x features) and a random projection
        to an n_components space. The sample parameter allows the user to randomly pick a subset of the data for
        performance or visualization reasons"""

        X = self.get_training_set()[0]
        X = X[np.random.choice(X.shape[0], sample, replace=False), :]

        rproj = random_projection.GaussianRandomProjection(n_components=n_components)
        X = rproj.fit_transform(X)

        return X, rproj

    def pca(
        self, n_components: int = None, sample: int = 1000, kernel: str = "linear"
    ) -> Tuple[Any, Any]:
        """Returns a training set generated from the 2D original data (time x features) and a PCA projection
        to an n_components space. The sample parameter allows the user to randomly pick a subset of the data for
        performance or visualization reasons"""

        X = self.get_training_set()[0]
        X = X[np.random.choice(X.shape[0], sample, replace=False), :]

        pca = KernelPCA(n_components=n_components, kernel=kernel)
        X = pca.fit_transform(X)

        return X, pca

    def tsne(
        self, n_components: int = None, sample: int = 1000, perplexity: int = 30
    ) -> Tuple[Any, Any]:
        """Returns a training set generated from the 2D original data (time x features) and a t-SNE projection
        to an n_components space. The sample parameter allows the user to randomly pick a subset of the data for
        performance or visualization reasons"""

        X = self.get_training_set()[0]
        X = X[np.random.choice(X.shape[0], sample, replace=False), :]

        tsne = TSNE(n_components=n_components, perplexity=perplexity)
        X = tsne.fit_transform(X)

        return X, tsne


def merge_tables(*args):
    """

    Takes a number of table_dict objects and merges them
    Returns a table_dict object of type 'merged'

    """
    merged_dict = {key: [] for key in args[0].keys()}
    for tabdict in args:
        for key, val in tabdict.items():
            merged_dict[key].append(val)

    merged_tables = table_dict(
        {
            key: pd.concat(val, axis=1, ignore_index=True)
            for key, val in merged_dict.items()
        },
        typ="merged",
    )

    return merged_tables
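
# A minimal usage sketch (hypothetical names): feature sets derived from the same
# coordinates object can be combined into a single table per experiment, e.g.
#
#   merged = merge_tables(my_coords.get_distances(), my_coords.get_angles())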


# TODO:
#   - Generate ragged training array using a metric (acceleration, maybe?)
#   - Use something like Dynamic Time Warping to put all instances in the same length
#   - add rule_based_annotation method to coordinates class!!
#   - with the current implementation, preprocess can't fully work on merged table_dict instances.
#   While some operations (mainly alignment) should be carried out before merging, others require
#   the whole dataset to function properly.