# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

deep autoencoder models for unsupervised pose detection

"""
10

lucas_miranda's avatar
lucas_miranda committed
11
from typing import Any, Dict, Tuple
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import BinaryCrossentropy, Huber
from tensorflow.keras.optimizers import Nadam
import deepof.model_utils
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# noinspection PyDefaultArgument
class SEQ_2_SEQ_AE:
    """Simple sequence-to-sequence autoencoder implemented with tf.keras"""

    def __init__(
        self, architecture_hparams: Dict = {}, huber_delta: float = 1.0,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.CONV_filters = self.hparams["units_conv"]
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.learn_rate = self.hparams["learning_rate"]
        self.delta = huber_delta

    @staticmethod
    def get_hparams(hparams: Dict) -> Dict:
        """Sets the default parameters for the model. Overridable with a dictionary"""

        defaults = {
            "units_conv": 256,
            "units_lstm": 256,
            "units_dense2": 64,
            "dropout_rate": 0.25,
            "encoding": 16,
            "learning_rate": 1e-3,
        }

        for k, v in hparams.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""
lucas_miranda's avatar
lucas_miranda committed
69

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="elu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="elu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
                2, weightage=1.0
            ),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
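        # The DenseTranspose layers (deepof.model_utils) mirror the paired encoder
        # layers; as the name suggests, each reuses the transposed weights of the
        # layer it wraps, keeping encoder and decoder dense weights tied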
        Model_D0 = deepof.model_utils.DenseTranspose(
            Model_E5, activation="elu", output_dim=self.ENCODING,
        )
        Model_D1 = deepof.model_utils.DenseTranspose(
            Model_E4, activation="elu", output_dim=self.DENSE_2,
        )
        Model_D2 = deepof.model_utils.DenseTranspose(
            Model_E3, activation="elu", output_dim=self.DENSE_1,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
            )
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        )

    def build(self, input_shape: tuple,) -> Tuple[Any, Any, Any]:
        """Builds the tf.keras model"""

        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(delta=self.delta),
            optimizer=Nadam(learning_rate=self.learn_rate, clipvalue=0.5,),
            metrics=["mae"],
        )

        model.build(input_shape)

        return encoder, decoder, model
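
# A minimal usage sketch (hypothetical variable names; assumes a preprocessed
# array X_train of shape (samples, timesteps, features)):
#
#   ae = SEQ_2_SEQ_AE()
#   encoder, decoder, model = ae.build(X_train.shape)
#   model.fit(X_train, X_train, epochs=25, batch_size=256)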


# noinspection PyDefaultArgument
class SEQ_2_SEQ_GMVAE:
    """Gaussian Mixture Variational Autoencoder for pose motif elucidation."""

    def __init__(
        self,
        architecture_hparams: dict = {},
        batch_size: int = 256,
        compile_model: bool = True,
        entropy_reg_weight: float = 0.0,
        huber_delta: float = 1.0,
        initialiser_iters: int = 1,
        kl_warmup_epochs: int = 0,
        loss: str = "ELBO+MMD",
        mmd_warmup_epochs: int = 0,
        number_of_components: int = 1,
        overlap_loss: float = 0.0,
        phenotype_prediction: float = 0.0,
        predictor: float = 0.0,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.batch_size = batch_size
        self.CONV_filters = self.hparams["units_conv"]
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.clipvalue = self.hparams["clipvalue"]
        self.dense_activation = self.hparams["dense_activation"]
        self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
        self.learn_rate = self.hparams["learning_rate"]
        self.lstm_unroll = True
        self.compile = compile_model
        self.delta = huber_delta
        self.entropy_reg_weight = entropy_reg_weight
        self.initialiser_iters = initialiser_iters
        self.kl_warmup = kl_warmup_epochs
        self.loss = loss
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        self.overlap_loss = overlap_loss
        self.phenotype_prediction = phenotype_prediction
        self.predictor = predictor
        self.prior = "standard_normal"

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    @property
    def prior(self):
        """Property to set the value of the prior
        once the class is instantiated"""

        return self._prior

    def get_prior(self):
        """Sets the Variational Autoencoder prior distribution"""

        if self.prior == "standard_normal":
            init_means = deepof.model_utils.far_away_uniform_initialiser(
                shape=(self.number_of_components, self.ENCODING),
                minval=0,
                maxval=5,
                iters=self.initialiser_iters,
            )

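            # Uniform mixture weights over components; each component is an
            # isotropic unit-variance Gaussian centred on one of the
            # far-apart initialised means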
            self.prior = tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                components=[
                    tfd.Independent(
                        tfd.Normal(loc=init_means[k], scale=1,),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            )

        else:  # pragma: no cover
            raise NotImplementedError(
                "Gaussian Mixtures are currently the only supported prior"
            )

    @staticmethod
    def get_hparams(params: Dict) -> Dict:
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "clipvalue": 0.5,
            "dense_activation": "elu",
            "dense_layers_per_branch": 1,
            "dropout_rate": 0.15,
            "encoding": 16,
            "learning_rate": 1e-3,
            "units_conv": 256,
            "units_dense2": 64,
            "units_lstm": 256,
        }

        for k, v in params.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""
335
336
337
338
339
340
341

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            # kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        Model_E4 = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                # kernel_constraint=UnitNorm(axis=0),
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_D1 = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]
        Model_D2 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )

        # Predictor layers
        Model_P1 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_P2 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )
        Model_P3 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )

        # Phenotype classification layers
        Model_PC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_PC1,
        )

    def build(self, input_shape: Tuple):
        """Builds the tf.keras model"""

        # Instantiate prior
        self.get_prior()

        # Get model layers
        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_PC1,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        x = Input(shape=input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Sequential(Model_E4)(encoder)
        encoder = BatchNormalization()(encoder)

        # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
        z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
        z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
        z_gauss = Dense(
            deepof.model_utils.tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            ),
            activation=None,
        )(encoder)

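        # The Dense head above emits 2 * ENCODING * number_of_components values
        # (a location and a scale per latent dimension and component); reshape
        # so that each mixture component occupies its own column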
        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)

        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
        z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)

        if self.overlap_loss:
            z_gauss = deepof.model_utils.Gaussian_mixture_overlap(
                self.ENCODING, self.number_of_components, loss=self.overlap_loss,
            )(z_gauss)

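        # Latent posterior: a mixture of independent Gaussians, with z_cat
        # supplying the mixture weights and z_gauss the per-component locations
        # and (softplus-transformed) scales; convert_to_tensor_fn="sample"
        # passes samples on to downstream layers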
        z = deepof.model_utils.tfpl.DistributionLambda(
            lambda gauss: tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(probs=gauss[0],),
                components=[
                    tfd.Independent(
                        tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
            convert_to_tensor_fn="sample",
        )([z_cat, z_gauss])

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = deepof.model_utils.K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        kl_beta, deepof.model_utils.K.min([epoch / self.kl_warmup, 1])
                    )
                )
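            # kl_beta thus ramps linearly from 0 to 1 over kl_warmup epochs,
            # annealing the KL term in gradually to mitigate early posterior collapse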

            z = deepof.model_utils.KLDivergenceLayer(self.prior, weight=kl_beta)(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = deepof.model_utils.K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        mmd_beta, deepof.model_utils.K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = deepof.model_utils.MMDiscrepancyLayer(
                batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
            )(z)

        # Define and instantiate generator
        generator = Sequential(Model_D1)(z)
        generator = Model_B1(generator)
        generator = Model_D2(generator)
        generator = Model_B2(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B3(generator)
        generator = Model_D5(generator)
        generator = Model_B4(generator)
        x_decoded_mean = TimeDistributed(
            Dense(input_shape[2]), name="vaep_reconstruction"
        )(generator)

        model_outs = [x_decoded_mean]
        model_losses = [Huber(delta=self.delta, reduction="sum")]
        model_metrics = {"vaep_reconstruction": ["mae", "mse"]}
        loss_weights = [1.0]

        if self.predictor > 0:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P1(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(input_shape[1])(predictor)
            predictor = Model_P2(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P3(predictor)
            predictor = BatchNormalization()(predictor)
            x_predicted_mean = TimeDistributed(
                Dense(input_shape[2]), name="vaep_prediction"
            )(predictor)

            model_outs.append(x_predicted_mean)
            model_losses.append(Huber(delta=self.delta, reduction="sum"))
            model_metrics["vaep_prediction"] = ["mae", "mse"]
            loss_weights.append(self.predictor)

        if self.phenotype_prediction > 0:
            pheno_pred = Model_PC1(z)
            pheno_pred = Dense(1, activation="sigmoid", name="phenotype_prediction")(
                pheno_pred
            )

            model_outs.append(pheno_pred)
            model_losses.append(BinaryCrossentropy())
            model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
            loss_weights.append(self.phenotype_prediction)

        # End-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        grouper = Model(x, z_cat, name="Deep_Gaussian_Mixture_clustering")
        # noinspection PyUnboundLocalVariable
        gmvaep = Model(inputs=x, outputs=model_outs, name="SEQ_2_SEQ_GMVAE",)
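        # encoder maps inputs to latent samples, grouper to soft cluster
        # assignments, and gmvaep is the end-to-end trainable model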

        # Build generator as a separate entity
        g = Input(shape=self.ENCODING)
        _generator = Sequential(Model_D1)(g)
        _generator = Model_B1(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B4(_generator)
        _x_decoded_mean = TimeDistributed(Dense(input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        if self.compile:
            gmvaep.compile(
                loss=model_losses,
                optimizer=Nadam(learning_rate=self.learn_rate, clipvalue=self.clipvalue,),
                metrics=model_metrics,
                loss_weights=loss_weights,
            )

        gmvaep.build(input_shape)

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            kl_warmup_callback,
            mmd_warmup_callback,
        )

    @prior.setter
    def prior(self, value):
        self._prior = value
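
# A minimal usage sketch (hypothetical variable names; assumes a preprocessed
# array X_train of shape (samples, timesteps, features)):
#
#   gmvae = SEQ_2_SEQ_GMVAE(number_of_components=10, kl_warmup_epochs=5)
#   encoder, generator, grouper, gmvaep, kl_cb, mmd_cb = gmvae.build(X_train.shape)
#   gmvaep.fit(
#       X_train, X_train, epochs=25, batch_size=256,
#       callbacks=[cb for cb in (kl_cb, mmd_cb) if cb],
#   )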


# TODO:
#       - Check KL weight in the overall loss function! Are we scaling the loss components correctly?
#       - Check merge mode in LSTM layers. Maybe we can drastically reduce model size!
#       - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
#       - Investigate posterior collapse (L1 as kernel/activity regulariser does not work)