# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

deep autoencoder models for unsupervised pose detection

"""

from typing import Any, Dict, Tuple
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Nadam
import deepof.model_utils
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# noinspection PyDefaultArgument
class SEQ_2_SEQ_AE:
    """

        Simple sequence to sequence autoencoder implemented with tf.keras

            Parameters:
                -

            Returns:
                -

        """

    def __init__(
        self, architecture_hparams: Dict = {}, huber_delta: float = 100.0,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.CONV_filters = self.hparams["units_conv"]
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.learn_rate = self.hparams["learning_rate"]
        self.delta = huber_delta

    @staticmethod
    def get_hparams(hparams):
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "units_conv": 256,
            "units_lstm": 256,
            "units_dense2": 64,
            "dropout_rate": 0.25,
            "encoding": 16,
            "learning_rate": 1e-5,
        }

        for k, v in hparams.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="elu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="elu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
                2, weightage=1.0
            ),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
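        # Model_D0-D2 are built from the corresponding encoder Dense layers via
        # DenseTranspose, which (as the constructor argument suggests) ties the
        # decoder weights to the encoder's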
        Model_D0 = deepof.model_utils.DenseTranspose(
            Model_E5, activation="elu", output_dim=self.ENCODING,
        )
        Model_D1 = deepof.model_utils.DenseTranspose(
            Model_E4, activation="elu", output_dim=self.DENSE_2,
        )
        Model_D2 = deepof.model_utils.DenseTranspose(
            Model_E3, activation="elu", output_dim=self.DENSE_1,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        )

    def build(self, input_shape: tuple) -> Tuple[Any, Any, Any]:
        """Builds the tf.keras model"""

        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(reduction="sum", delta=self.delta),
            optimizer=Nadam(learning_rate=self.learn_rate, clipvalue=0.5),
            metrics=["mae"],
        )

        model.build(input_shape)

        return encoder, decoder, model
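
# Example usage (a minimal sketch; the input shape and training array are
# hypothetical, with input_shape read as (batch size, window length, features)):
#
#   ae = SEQ_2_SEQ_AE(architecture_hparams={"encoding": 8})
#   encoder, decoder, model = ae.build(input_shape=(512, 22, 28))
#   model.fit(X_train, X_train, batch_size=512, epochs=25)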


# noinspection PyDefaultArgument
class SEQ_2_SEQ_GMVAE:
    """

    Gaussian Mixture Variational Autoencoder for pose motif elucidation.

        Parameters:
            -

        Returns:
            -

    """

    def __init__(
        self,
        architecture_hparams: dict = {},
        loss: str = "ELBO+MMD",
        kl_warmup_epochs: int = 0,
        mmd_warmup_epochs: int = 0,
        number_of_components: int = 1,
        predictor: float = 1.0,
        overlap_loss: bool = False,
        entropy_reg_weight: float = 0.0,
        initialiser_iters: int = int(1e5),
        huber_delta: float = 100.0,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.batch_size = self.hparams["batch_size"]
        self.CONV_filters = self.hparams["units_conv"]
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.learn_rate = self.hparams["learning_rate"]
        self.loss = loss
        self.prior = "standard_normal"
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        self.predictor = predictor
        self.overlap_loss = overlap_loss
        self.entropy_reg_weight = entropy_reg_weight
        self.initialiser_iters = initialiser_iters
        self.delta = huber_delta

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    @property
    def prior(self):
        """Property to set the value of the prior
        once the class is instanciated"""

        return self._prior

    def get_prior(self):
        """Sets the Variational Autoencoder prior distribution"""

        if self.prior == "standard_normal":
            init_means = deepof.model_utils.far_away_uniform_initialiser(
                shape=(self.number_of_components, self.ENCODING),
                minval=0,
                maxval=5,
                iters=self.initialiser_iters,
            )

            self.prior = deepof.model_utils.tfd.mixture.Mixture(
                cat=deepof.model_utils.tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                components=[
                    deepof.model_utils.tfd.Independent(
                        deepof.model_utils.tfd.Normal(loc=init_means[k], scale=1,),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            )

        else:  # pragma: no cover
            raise NotImplementedError(
                "Gaussian Mixtures are currently the only supported prior"
            )
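
    # A quick sanity check of the prior (a sketch; the component count and
    # sample size are arbitrary):
    #
    #   gmvae = SEQ_2_SEQ_GMVAE(number_of_components=5)
    #   gmvae.get_prior()
    #   samples = gmvae.prior.sample(10)  # -> Tensor of shape (10, ENCODING)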

    @staticmethod
    def get_hparams(params: Dict) -> Dict:
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "batch_size": 512,
            "units_conv": 256,
            "units_lstm": 256,
            "units_dense2": 64,
            "dropout_rate": 0.25,
            "encoding": 16,
            "learning_rate": 1e-3,
        }

        for k, v in params.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="elu",
            kernel_initializer=he_uniform(),
            use_bias=False,
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
                use_bias=False,
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
                use_bias=False,
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=False,
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=False,
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_D1 = Dense(
            self.DENSE_2,
            activation="elu",
            kernel_initializer=he_uniform(),
            use_bias=False,
        )
        Model_D2 = Dense(
            self.DENSE_1,
            activation="elu",
            kernel_initializer=he_uniform(),
            use_bias=False,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
                use_bias=False,
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
                use_bias=False,
            )
        )
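
        # Predictor layers, used only by the auxiliary prediction branch
        # (active when self.predictor > 0)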
        Model_P1 = Dense(
            self.DENSE_1,
            activation="elu",
            kernel_initializer=he_uniform(),
            use_bias=False,
        )
        Model_P2 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
                use_bias=False,
            )
        )
        Model_P3 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
                use_bias=False,
            )
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
        )

    def build(self, input_shape: Tuple):
        """Builds the tf.keras model"""

        # Instantiate prior
        self.get_prior()

        # Get model layers
        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        x = Input(shape=input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)

        encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
        z_cat = Dense(self.number_of_components, activation="softmax")(
            encoding_shuffle
        )
        z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
        z_gauss = Dense(
            deepof.model_utils.tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            ),
            activation=None,
        )(encoder)

        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)

        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
        z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)

        if self.overlap_loss:
            z_gauss = deepof.model_utils.Gaussian_mixture_overlap(
                self.ENCODING, self.number_of_components, loss=self.overlap_loss,
            )(z_gauss)

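        # The posterior is a mixture of self.number_of_components independent
        # Gaussians: column k of z_gauss holds component k's parameters, with the
        # first ENCODING rows read as means and the last ENCODING rows passed
        # through a softplus to yield positive scales; z_cat supplies the
        # mixture weights.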
        z = deepof.model_utils.tfpl.DistributionLambda(
            lambda gauss: deepof.model_utils.tfd.mixture.Mixture(
                cat=deepof.model_utils.tfd.categorical.Categorical(probs=gauss[0],),
                components=[
                    deepof.model_utils.tfd.Independent(
                        deepof.model_utils.tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
        )([z_cat, z_gauss])

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = deepof.model_utils.K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        kl_beta, deepof.model_utils.K.min([epoch / self.kl_warmup, 1])
                    )
                )

            z = deepof.model_utils.KLDivergenceLayer(self.prior, weight=kl_beta)(z)
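
        # Both warm-up callbacks anneal their weight linearly from 0 to 1 over
        # the configured number of epochs: beta = min(epoch / warmup_epochs, 1)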

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = deepof.model_utils.K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        mmd_beta, deepof.model_utils.K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = deepof.model_utils.MMDiscrepancyLayer(
                batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
            )(z)

        # Define and instantiate generator
        generator = Model_D1(z)
        generator = Model_B1(generator)
        generator = Model_D2(generator)
        generator = Model_B2(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B3(generator)
        generator = Model_D5(generator)
        generator = Model_B4(generator)
        x_decoded_mean = TimeDistributed(
            Dense(input_shape[2]), name="vaep_reconstruction"
        )(generator)

        if self.predictor > 0:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2, activation="elu", kernel_initializer=he_uniform()
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P1(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(input_shape[1])(predictor)
            predictor = Model_P2(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P3(predictor)
            predictor = BatchNormalization()(predictor)
            x_predicted_mean = TimeDistributed(
                Dense(input_shape[2]), name="vaep_prediction"
            )(predictor)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        grouper = Model(x, z_cat, name="Deep_Gaussian_Mixture_clustering")
        # noinspection PyUnboundLocalVariable
        gmvaep = Model(
            inputs=x,
            outputs=(
                [x_decoded_mean, x_predicted_mean]
                if self.predictor > 0
                else x_decoded_mean
            ),
            name="SEQ_2_SEQ_GMVAE",
        )

        # Build generator as a separate entity
        g = Input(shape=(self.ENCODING,))
        _generator = Model_D1(g)
        _generator = Model_B1(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B4(_generator)
        _x_decoded_mean = TimeDistributed(Dense(input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
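
        # The standalone generator reuses the same decoder layer objects
        # (Model_D* and Model_B*) as the end-to-end model, so the two share weights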

        def huber_loss(x_, x_decoded_mean_):  # pragma: no cover
            """Computes Huber loss with a fixed delta"""

            huber = Huber(reduction="sum", delta=self.delta)
            return input_shape[1:] * huber(x_, x_decoded_mean_)

        gmvaep.compile(
            loss=huber_loss,
            optimizer=Nadam(learning_rate=self.learn_rate, clipvalue=0.5),
            metrics=["mae"],
            loss_weights=([1, self.predictor] if self.predictor > 0 else [1]),
        )

        gmvaep.build(input_shape)

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            kl_warmup_callback,
            mmd_warmup_callback,
        )

    @prior.setter
    def prior(self, value):
        self._prior = value
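
# Example usage (a minimal sketch; input shape, arrays and hyperparameter
# values are hypothetical):
#
#   gmvae = SEQ_2_SEQ_GMVAE(loss="ELBO+MMD", number_of_components=10, kl_warmup_epochs=10)
#   encoder, generator, grouper, gmvaep, kl_cb, mmd_cb = gmvae.build((512, 22, 28))
#   callbacks = [cb for cb in (kl_cb, mmd_cb) if cb]
#   # With predictor > 0 the model has two outputs: a reconstruction of the
#   # current window and a second output trained against a prediction target
#   # of the same shape:
#   gmvaep.fit(X_train, [X_train, X_next], batch_size=512, callbacks=callbacks)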


# TODO:
#       - Investigate posterior collapse (L1 as kernel/activity regulariser does not work)
#       - Random horizontal flip for data augmentation
#       - Align first frame and untamper sliding window (reduce window stride)
#       - Design a clustering-conscious hyperparameter tuning pipeline
#       - Execute the pipeline ;)