# preprocess.py
# @author lucasmiranda42
from collections import defaultdict
from copy import deepcopy
from pandas_profiling import ProfileReport
from sklearn import random_projection
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
9
from sklearn.preprocessing import MinMaxScaler, StandardScaler
10
11
import os
import warnings
12
13
import networkx as nx

14
from deepof.utils import *
15
from deepof.visuals import *
16

17

18
class project:
    """

    Class for loading and preprocessing DLC data of individual and social mice.

    Instantiating the class lists the video and table files found under ``path``
    and runs arena recognition on every video (see ``get_scale``); ``run`` then
    loads the tables and computes the requested features.

    """

    def __init__(
        self,
        video_format=".mp4",
        table_format=".h5",
        path=".",
        exp_conditions=False,
        subset_condition=None,
        arena="circular",
        smooth_alpha=0.1,
        arena_dims=None,
        distances="All",
        ego=False,
        angles=True,
        connectivity=None,
    ):
        """
        Parameters:
            video_format (str): extension of the video files in ``<path>/Videos/``;
            table_format (str): extension of the tracking tables in
                ``<path>/Tables/`` (".h5" or ".csv");
            path (str): project root containing the ``Videos`` and ``Tables`` folders;
            exp_conditions: experimental conditions; only its length is used here
                (in ``__str__``) before being forwarded to ``coordinates``;
            subset_condition (str): if set, only bodyparts whose name starts with
                this prefix are kept;
            arena (str): arena type; only "circular" is supported;
            smooth_alpha (float): parameter passed to ``smooth_mult_trajectory``;
                a falsy value disables smoothing;
            arena_dims (list): values appended to every per-video scale row
                (defaults to ``[1]``);
            distances ("All" or list): bodyparts between which distances are
                computed; a falsy value disables distances;
            ego (str or False): if set, only distances involving this bodypart
                are kept;
            angles (bool): whether to compute angles between connected trios;
            connectivity (dict): which bodyparts each bodypart is connected to;
                its 3-cliques define the angles.
        """

        self.path = path
        self.video_path = self.path + "/Videos/"
        self.table_path = self.path + "/Tables/"

        self.videos = sorted(
            [vid for vid in os.listdir(self.video_path) if vid.endswith(video_format)]
        )
        self.tables = sorted(
            [tab for tab in os.listdir(self.table_path) if tab.endswith(table_format)]
        )
        self.exp_conditions = exp_conditions
        self.subset_condition = subset_condition
        self.table_format = table_format
        self.video_format = video_format
        self.arena = arena
        # Fresh list per instance instead of a shared mutable default argument.
        self.arena_dims = [1] if arena_dims is None else arena_dims
        self.smooth_alpha = smooth_alpha
        self.distances = distances
        self.ego = ego
        self.angles = angles
        self.connectivity = connectivity

        # Runs arena recognition on every video (I/O heavy).
        self.scales = self.get_scale

        # assert [re.findall("(.*)_", vid)[0] for vid in self.videos] == [
        #     re.findall("(.*)\.", tab)[0] for tab in self.tables
        # ], "Video files should match table files"

    def __str__(self):
        if self.exp_conditions:
            return "DLC analysis of {} videos across {} conditions".format(
                len(self.videos), len(self.exp_conditions)
            )
        else:
            return "DLC analysis of {} videos".format(len(self.videos))

    def load_tables(self, verbose):
        """Loads videos and tables into dictionaries

        Parameters:
            verbose (bool): print progress messages.

        Returns:
            table_dict (dict): per-video DataFrames of x and y positions with
                (bodypart, coordinate) columns, keyed by the part of the table
                filename before its first underscore;
            lik_dict (defaultdict): per-video DataFrames of the likelihood columns.

        Raises:
            NotImplementedError: if ``table_format`` is neither ".h5" nor ".csv".
        """

        if self.table_format not in [".h5", ".csv"]:
            raise NotImplementedError(
                "Tracking files must be in either h5 or csv format"
            )

        if verbose:
            print("Loading trajectories...")

        # NOTE(review): the code below expects a "coords" column level; a plain
        # pd.read_csv presumably does not produce it for multi-row-header csv
        # exports — confirm the csv path against real data.
        reader = pd.read_hdf if self.table_format == ".h5" else pd.read_csv
        table_dict = {
            re.findall("(.*?)_", tab)[0]: reader(self.table_path + tab, dtype=float)
            for tab in self.tables
        }

        lik_dict = defaultdict()

        for key, value in table_dict.items():
            x = value.xs("x", level="coords", axis=1, drop_level=False)
            y = value.xs("y", level="coords", axis=1, drop_level=False)
            lik: pd.DataFrame = value.xs(
                "likelihood", level="coords", axis=1, drop_level=True
            )

            table_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
            lik_dict[key] = lik

        if self.smooth_alpha:

            if verbose:
                print("Smoothing trajectories...")

            for key, tab in table_dict.items():
                cols = tab.columns
                smooth = pd.DataFrame(
                    smooth_mult_trajectory(np.array(tab), alpha=self.smooth_alpha)
                )
                smooth.columns = cols
                table_dict[key] = smooth

        # Drop the outermost column level by selecting its first value.
        for key, tab in table_dict.items():
            table_dict[key] = tab[tab.columns.levels[0][0]]

        if self.subset_condition:
            for key, value in table_dict.items():
                lablist = [
                    b
                    for b in value.columns.levels[0]
                    if not b.startswith(self.subset_condition)
                ]

                tabcols = value.drop(
                    lablist, axis=1, level=0
                ).T.index.remove_unused_levels()

                tab = value.loc[
                    :, [i for i in value.columns.levels[0] if i not in lablist]
                ]

                tab.columns = tabcols

                table_dict[key] = tab

        return table_dict, lik_dict

    @property
    def get_scale(self):
        """Returns the arena as recognised from the videos

        One row per video: ``list(recognize_arena(...) * 2) + arena_dims``.

        Raises:
            NotImplementedError: if ``arena`` is not "circular".
        """

        if self.arena in ["circular"]:

            scales = []
            for vid_index, _ in enumerate(self.videos):
                scales.append(
                    list(
                        recognize_arena(
                            self.videos,
                            vid_index,
                            path=self.video_path,
                            arena_type=self.arena,
                        )
                        * 2
                    )
                    + self.arena_dims
                )

        else:
            raise NotImplementedError("arenas must be set to one of: 'circular'")

        return np.array(scales)

    def get_distances(self, table_dict, verbose):
        """Computes the distances between all selected bodyparts over time.
           If ego is provided, it only returns distances to a specified bodypart

        Parameters:
            table_dict (dict): coordinate tables as returned by ``load_tables``;
            verbose (bool): print progress messages.

        Returns:
            distance_dict (dict): per-video DataFrames of distances whose
                endpoints are all selected bodyparts.

        Raises:
            ValueError: if a requested bodypart is not present in the tables.
        """

        if verbose:
            print("Computing distances...")

        available = list(table_dict.values())[0].columns.levels[0]

        nodes = self.distances
        if nodes == "All":
            nodes = available

        # Validate every requested node (a bare list in an assert is always truthy).
        missing = [node for node in nodes if node not in available]
        if missing:
            raise ValueError(
                "Nodes should correspond to existent bodyparts. Unknown: {}".format(
                    missing
                )
            )

        scales = self.scales[:, 2:]

        distance_dict = {
            key: bpart_distance(tab, scales[i, 1], scales[i, 0],)
            for i, (key, tab) in enumerate(table_dict.items())
        }

        # Keep only the distances whose endpoints are all selected nodes.
        for key, dists in distance_dict.items():
            distance_dict[key] = dists.loc[
                :, [all(bpart in nodes for bpart in pair) for pair in dists.columns]
            ]

        if self.ego:
            for key, val in distance_dict.items():
                distance_dict[key] = val.loc[
                    :, [dist for dist in val.columns if self.ego in dist]
                ]

        return distance_dict

    def get_angles(self, table_dict, verbose):
        """

        Computes all the angles between adjacent bodypart trios per video and per frame in the data.
        Parameters (from self):
            connectivity (dictionary): dict stating to which bodyparts each bodypart is connected;
            table_dict (dict of dataframes): tables loaded from the data;

        Output:
            angle_dict (dictionary): dict containing angle dataframes per video;
                when the connectivity graph has no 3-cliques, each dataframe
                has no columns.

        """

        if verbose:
            print("Computing angles...")

        bp_net = nx.Graph(self.connectivity)
        cliques = [c for c in nx.enumerate_all_cliques(bp_net) if len(c) == 3]
        # The three angles of a trio, labelled by the bodypart order of each.
        orders = [[0, 1, 2], [0, 2, 1], [1, 0, 2]]

        angle_dict = {}
        for key, tab in table_dict.items():

            dats = []
            for clique in cliques:
                # Build a (3, n_frames, 2) array explicitly per bodypart; a direct
                # reshape of the (n_frames, 6) table would mix frames and bodyparts.
                trio = np.stack(
                    [np.asarray(tab[bpart][["x", "y"]]) for bpart in clique]
                )
                dat = pd.DataFrame(angle_trio(trio)).T
                dat.columns = [tuple(clique[i] for i in order) for order in orders]
                dats.append(dat)

            angle_dict[key] = (
                pd.concat(dats, axis=1) if dats else pd.DataFrame(index=tab.index)
            )

        return angle_dict

    def run(self, verbose=False):
        """Generates a dataset using all the options specified during initialization

        Parameters:
            verbose (bool): print progress messages.

        Returns:
            coordinates: object wrapping tables, likelihoods, distances and angles.
        """

        tables, quality = self.load_tables(verbose)
        distances = None
        angles = None

        if self.distances:
            distances = self.get_distances(tables, verbose)

        if self.angles:
            angles = self.get_angles(tables, verbose)

        if verbose:
            print("Done!")

        return coordinates(
            tables=tables,
            videos=self.videos,
            arena=self.arena,
            arena_dims=self.arena_dims,
            scales=self.scales,
            quality=quality,
            exp_conditions=self.exp_conditions,
            distances=distances,
            angles=angles,
        )


class coordinates:
    """

    Container for the output of ``project.run``: per-video coordinate tables,
    likelihood tables ("quality"), optional distances and angles, and arena
    metadata. The ``get_*`` methods return transformed copies as ``table_dict``
    objects and leave the stored data untouched.

    """

    def __init__(
        self,
        tables,
        videos,
        arena,
        arena_dims,
        scales,
        quality,
        exp_conditions=None,
        distances=None,
        angles=None,
    ):
        self._tables = tables
        self.distances = distances
        self.angles = angles
        self._videos = videos
        self._exp_conditions = exp_conditions
        self._arena = arena
        self._arena_dims = arena_dims
        self._scales = scales
        self._quality = quality

    def __str__(self):
        if self._exp_conditions:
            return "Coordinates of {} videos across {} conditions".format(
                len(self._videos), len(self._exp_conditions)
            )
        return "Coordinates of {} videos".format(len(self._videos))

    @staticmethod
    def _apply_speed(tabs, speed, typ):
        """Replaces every table in ``tabs`` by its rolling speed, ``speed`` times in a row.

        Pass ``n`` (0-based) calls ``rolling_speed`` with ``order=n + 1`` on the
        output of the previous pass, then relabels the result with the top
        column level of its input (or its plain columns if not a MultiIndex).
        """
        # NOTE(review): passes are chained (order 2 runs on the output of
        # order 1, ...); confirm this matches rolling_speed's meaning of "order".
        for order in range(speed):
            for key, tab in tabs.items():
                try:
                    cols = tab.columns.levels[0]
                except AttributeError:
                    cols = tab.columns
                vel = rolling_speed(tab, typ=typ, order=order + 1)
                vel.columns = cols
                tabs[key] = vel
        return tabs

    @staticmethod
    def _set_time_index(tabs, length):
        """Re-indexes every table in place with evenly spaced timedeltas in [0, length)."""
        for tab in tabs.values():
            tab.index = pd.timedelta_range(
                "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
            )
        return tabs

    def get_coords(
        self, center="arena", polar=False, speed=0, length=None, align=False
    ):
        """Returns a table_dict of (optionally transformed) coordinates.

        Parameters:
            center: "arena" subtracts half of the first and second per-video
                scale entries from every x and y column, respectively (circular
                arenas only). A bodypart name expresses every position relative
                to that bodypart, which is then dropped. None/False leaves
                positions untouched;
            polar (bool): convert with ``tab2polar``. This is done after
                centering, so polar coordinates are taken around the chosen center;
            speed (int): number of chained ``rolling_speed`` passes;
            length (str): if set, each index becomes a timedelta range in [0, length);
            align: bodypart whose two coordinate columns are moved to the front.

        Returns:
            table_dict of type "coords".
        """

        tabs = deepcopy(self._tables)

        # Center in Cartesian space, before any polar conversion: subtracting
        # pixel offsets from (rho, phi) would be meaningless.
        if center == "arena":
            if self._arena == "circular":
                for i, tab in enumerate(tabs.values()):
                    for coord, scale_col in (("x", 0), ("y", 1)):
                        # slice(None) selects every bodypart on the first level.
                        cols = (slice(None), [coord])
                        tab.loc[:, cols] = (
                            tab.loc[:, cols] - self._scales[i][scale_col] / 2
                        )

        elif isinstance(center, str):
            for key, tab in tabs.items():
                for coord in ("x", "y"):
                    cols = (slice(None), [coord])
                    tab.loc[:, cols] = tab.loc[:, cols].subtract(
                        tab[center][coord], axis=0
                    )
                tabs[key] = tab.loc[:, [col for col in tab.columns if col[0] != center]]

        if polar:
            for key, tab in tabs.items():
                tabs[key] = tab2polar(tab)

        if speed:
            tabs = self._apply_speed(tabs, speed, "coords")

        if length:
            tabs = self._set_time_index(tabs, length)

        if align:
            assert (
                align in list(tabs.values())[0].columns.levels[0]
            ), "align must be set to the name of a bodypart"

            for key, tab in tabs.items():
                # Bring forward the column to align
                columns = [i for i in tab.columns if align not in i]
                columns = [
                    (align, ("phi" if polar else "x")),
                    (align, ("rho" if polar else "y")),
                ] + columns
                tabs[key] = tab[columns]

        return table_dict(
            tabs,
            "coords",
            arena=self._arena,
            arena_dims=self._scales,
            center=center,
            polar=polar,
        )

    def get_distances(self, speed=0, length=None):
        """Returns a table_dict of type "dists" with copies of the stored distances.

        Parameters:
            speed (int): number of chained ``rolling_speed`` passes;
            length (str): if set, each index becomes a timedelta range in [0, length).

        Raises:
            ValueError: if distances were not computed.
        """

        if self.distances is None:
            raise ValueError(
                "Distances not computed. Read the documentation for more details"
            )

        tabs = deepcopy(self.distances)

        if speed:
            tabs = self._apply_speed(tabs, speed, "dists")

        if length:
            tabs = self._set_time_index(tabs, length)

        return table_dict(tabs, typ="dists")

    def get_angles(self, degrees=False, speed=0, length=None):
        """Returns a table_dict of type "angles" with copies of the stored angles.

        Parameters:
            degrees (bool): convert with ``np.degrees``;
            speed (int): number of chained ``rolling_speed`` passes;
            length (str): if set, each index becomes a timedelta range in [0, length).

        Raises:
            ValueError: if angles were not computed.
        """

        if self.angles is None:
            raise ValueError("Angles not computed. Read the documentation for more details")

        tabs = deepcopy(self.angles)

        if degrees:
            tabs = {key: np.degrees(tab) for key, tab in tabs.items()}

        if speed:
            # NOTE(review): angles reuse typ="dists"; confirm rolling_speed
            # treats both the same way.
            tabs = self._apply_speed(tabs, speed, "dists")

        if length:
            tabs = self._set_time_index(tabs, length)

        return table_dict(tabs, typ="angles")

    def get_videos(self, play=False):
        """Returns the list of video file names (playback is not implemented)."""
        if play:
            raise NotImplementedError

        return self._videos

    @property
    def get_exp_conditions(self):
        return self._exp_conditions

    def get_quality(self, report=False):
        """Returns the likelihood tables, or a ProfileReport for the video key ``report``."""
        if report:
            profile = ProfileReport(
                self._quality[report],
                title="Quality Report, {}".format(report),
                html={"style": {"full_width": True}},
            )
            return profile
        return self._quality

    @property
    def get_arenas(self):
        return self._arena, self._arena_dims, self._scales


class table_dict(dict):
    """

    Dictionary of per-video tables tagged with how they were produced.

    Besides the tables it stores the table type ("coords", "dists", "angles",
    "merged", ...) and, for coordinates, the arena, the per-video arena scales,
    the centering used and whether coordinates are polar. ``plot_heatmaps``
    relies on that metadata and ``filter`` carries it over.

    """

    def __init__(self, tabs, typ, arena=None, arena_dims=None, center=None, polar=None):
        super().__init__(tabs)
        self._type = typ
        self._center = center
        self._polar = polar
        self._arena = arena
        self._arena_dims = arena_dims

    def filter(self, keys):
        """Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
         for selecting data coming from videos of a specified condition.

        All metadata is preserved; per-video arena scales are subset to match.
        """

        assert np.all([k in self.keys() for k in keys]), "Invalid keys selected"

        # Arena scales are indexed by video position, so keep the matching rows.
        positions = [i for i, k in enumerate(self.keys()) if k in keys]
        arena_dims = (
            None
            if self._arena_dims is None
            else np.array([self._arena_dims[i] for i in positions])
        )

        return table_dict(
            {k: value for k, value in self.items() if k in keys},
            self._type,
            arena=self._arena,
            arena_dims=arena_dims,
            center=self._center,
            polar=self._polar,
        )

    def plot_heatmaps(self, bodyparts, save=False, i=0):
        """Plots position heatmaps of ``bodyparts`` for the ``i``-th video.

        Raises:
            NotImplementedError: unless this holds cartesian coordinates.
        """

        if self._type != "coords" or self._polar:
            raise NotImplementedError(
                "Heatmaps only available for cartesian coordinates. Set polar to False in get_coordinates and try again"
            )

        if not self._center:
            warnings.warn("Heatmaps look better if you center the data")

        if self._arena == "circular":
            x_lim = (
                [-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
                if self._center
                else [0, self._arena_dims[i][0]]
            )
            y_lim = (
                [-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
                if self._center
                else [0, self._arena_dims[i][1]]
            )

            plot_heatmap(
                list(self.values())[i], bodyparts, xlim=x_lim, ylim=y_lim, save=save,
            )

    def get_training_set(self, test_videos=0):
        """Stacks all tables frame-wise into train/test arrays.

        Tables shorter than the longest one are zero-padded at the end.

        Parameters:
            test_videos (int): number of randomly chosen videos held out.

        Returns:
            X_train (np.ndarray): concatenated rows of the training videos;
            X_test (np.ndarray or list): concatenated rows of the test videos,
                or an empty list when ``test_videos`` is 0.
        """
        rmax = max([i.shape[0] for i in self.values()])
        raw_data = np.array(
            [np.pad(v, ((0, rmax - v.shape[0]), (0, 0))) for v in self.values()]
        )
        test_index = np.random.choice(range(len(raw_data)), test_videos, replace=False)

        X_test = []
        if test_videos > 0:
            X_test = np.concatenate(list(raw_data[test_index]))
            X_train = np.concatenate(list(np.delete(raw_data, test_index, axis=0)))

        else:
            X_train = np.concatenate(list(raw_data))

        return X_train, X_test

    @staticmethod
    def _gaussian_weights(window_size, sigma, shift=0):
        """Gaussian weights for each position of a window, rescaled so the peak is 1."""
        # Offsets centred on the window middle: [-2..2] for 5, [-1.5..1.5] for 4,
        # so exactly ``window_size`` weights are produced for odd and even sizes.
        r = np.arange(window_size) - (window_size - 1) / 2 - shift
        g = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(r ** 2) / (2 * sigma ** 2))
        return g / np.max(g)

    def preprocess(
        self,
        window_size=1,
        window_step=1,
        scale="standard",
        test_videos=0,
        verbose=False,
        filter=None,
        sigma=None,
        shift=0,
        shuffle=False,
        align=False,
    ):
        """Builds a sliding window. If specified, splits train and test and
           scales the data ("standard" or "minmax" sklearn scalers).

        Parameters:
            window_size, window_step (int): passed to ``rolling_window``;
            scale (str or None): "standard", "minmax" or a falsy value for none;
            test_videos (int): number of videos held out for testing;
            verbose (bool): print progress messages;
            filter (str): "gaussian" weights each window position by a Gaussian
                of width ``sigma`` centred at the window middle minus ``shift``;
            shuffle (bool): randomly permute the windows;
            align: "all" or "center", forwarded to ``align_trajectories``
                before / after windowing respectively.

        Returns:
            X_train, or (X_train, X_test) when ``test_videos`` is non-zero.
        """

        X_train, X_test = self.get_training_set(test_videos)

        if scale:
            if verbose:
                print("Scaling data...")

            if scale == "standard":
                scaler = StandardScaler()
            elif scale == "minmax":
                scaler = MinMaxScaler()
            else:
                raise ValueError(
                    "Invalid scaler. Select one of standard, minmax or None"
                )

            X_train = scaler.fit_transform(
                X_train.reshape(-1, X_train.shape[-1])
            ).reshape(X_train.shape)

            if scale == "standard":
                # NOTE(review): the std check presumably fails if some feature is
                # constant, since StandardScaler maps it to all zeros.
                assert np.allclose(np.mean(X_train), 0)
                assert np.allclose(np.std(X_train, ddof=1), 1)

            if test_videos:
                X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
                    X_test.shape
                )

            if verbose:
                print("Done!")

        if align == "all":
            X_train = align_trajectories(X_train, align)

        X_train = rolling_window(X_train, window_size, window_step)

        if align == "center":
            X_train = align_trajectories(X_train, align)

        if filter == "gaussian":
            g = self._gaussian_weights(window_size, sigma, shift)
            X_train = X_train * g.reshape(1, window_size, 1)

        if test_videos:

            if align == "all":
                X_test = align_trajectories(X_test, align)

            X_test = rolling_window(X_test, window_size, window_step)

            if align == "center":
                X_test = align_trajectories(X_test, align)

            if filter == "gaussian":
                X_test = X_test * g.reshape(1, window_size, 1)

            if shuffle:
                X_train = X_train[
                    np.random.choice(X_train.shape[0], X_train.shape[0], replace=False)
                ]
                X_test = X_test[
                    np.random.choice(X_test.shape[0], X_test.shape[0], replace=False)
                ]

            return X_train, X_test

        if shuffle:
            X_train = X_train[
                np.random.choice(X_train.shape[0], X_train.shape[0], replace=False)
            ]

        return X_train

    def _sample_training_rows(self, sample):
        """Returns up to ``sample`` random rows (without replacement) of the training set."""
        # get_training_set returns (X_train, X_test); only the training rows are used.
        X, _ = self.get_training_set()
        n_rows = X.shape[0]
        return X[np.random.choice(n_rows, min(sample, n_rows), replace=False), :]

    def random_projection(self, n_components=None, sample=1000):
        """Fits a Gaussian random projection on a row sample; returns (embedding, model)."""

        X = self._sample_training_rows(sample)

        rproj = random_projection.GaussianRandomProjection(n_components=n_components)
        X = rproj.fit_transform(X)

        return X, rproj

    def pca(self, n_components=None, sample=1000, kernel="linear"):
        """Fits a KernelPCA on a row sample; returns (embedding, model)."""

        X = self._sample_training_rows(sample)

        pca = KernelPCA(n_components=n_components, kernel=kernel)
        X = pca.fit_transform(X)

        return X, pca

    def tsne(self, n_components=None, sample=1000, perplexity=30):
        """Fits t-SNE on a row sample; returns (embedding, model).

        ``n_components=None`` falls back to 2, since TSNE requires an integer.
        """

        X = self._sample_training_rows(sample)

        tsne = TSNE(
            n_components=2 if n_components is None else n_components,
            perplexity=perplexity,
        )
        X = tsne.fit_transform(X)

        return X, tsne


def merge_tables(*args):
    """

    Merges several table_dict objects video by video.

    The keys of the first argument define the videos. For each of them the
    tables of every argument are concatenated side by side (column labels are
    replaced by a fresh integer range). Returns a table_dict of type 'merged'.

    """
    pieces = {key: [] for key in args[0].keys()}
    for source in args:
        for key in source:
            pieces[key].append(source[key])

    merged = {}
    for key, frames in pieces.items():
        merged[key] = pd.concat(frames, axis=1, ignore_index=True)

    return table_dict(merged, typ="merged")