# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

deep autoencoder models for unsupervised pose detection

"""

from typing import Any, Dict, Tuple
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import BinaryCrossentropy, Huber
from tensorflow.keras.optimizers import Nadam
import deepof.model_utils
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


# noinspection PyDefaultArgument
class SEQ_2_SEQ_AE:
    """Simple sequence-to-sequence autoencoder implemented with tf.keras"""

    def __init__(
        self, architecture_hparams: Dict = {}, huber_delta: float = 1.0,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.CONV_filters = self.hparams["units_conv"]
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.learn_rate = self.hparams["learning_rate"]
        self.delta = huber_delta

    @staticmethod
    def get_hparams(hparams):
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "units_conv": 256,
            "units_lstm": 256,
            "units_dense2": 64,
            "dropout_rate": 0.25,
            "encoding": 16,
            "learning_rate": 1e-3,
        }

        for k, v in hparams.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""
lucas_miranda's avatar
lucas_miranda committed
69

70
71
72
73
74
75
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="elu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="elu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
                2, weightage=1.0
            ),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_D0 = deepof.model_utils.DenseTranspose(
            Model_E5, activation="elu", output_dim=self.ENCODING,
        )
        Model_D1 = deepof.model_utils.DenseTranspose(
            Model_E4, activation="elu", output_dim=self.DENSE_2,
        )
        Model_D2 = deepof.model_utils.DenseTranspose(
            Model_E3, activation="elu", output_dim=self.DENSE_1,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
            )
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        )

    def build(self, input_shape: tuple,) -> Tuple[Any, Any, Any]:
        """Builds the tf.keras model"""

        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(delta=self.delta),
            optimizer=Nadam(learning_rate=self.learn_rate, clipvalue=0.5,),
            metrics=["mae"],
        )

        model.build(input_shape)

        return encoder, decoder, model
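
# Example usage for SEQ_2_SEQ_AE (illustrative sketch, not from the original
# source; the input shape and random data below are assumptions):
#
#     import numpy as np
#
#     ae = SEQ_2_SEQ_AE()
#     encoder, decoder, model = ae.build((256, 50, 20))  # (batch, frames, features)
#     x = np.random.normal(size=(256, 50, 20)).astype("float32")
#     model.fit(x, x, batch_size=64, epochs=1)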


# noinspection PyDefaultArgument
class SEQ_2_SEQ_GMVAE:
    """Gaussian Mixture Variational Autoencoder for pose motif elucidation."""

    def __init__(
        self,
        architecture_hparams: dict = {},
        batch_size: int = 256,
        compile_model: bool = True,
        dense_activation: str = "elu",
        entropy_reg_weight: float = 0.0,
        huber_delta: float = 1.0,
        initialiser_iters: int = int(1e4),
        kl_warmup_epochs: int = 0,
        loss: str = "ELBO+MMD",
        mmd_warmup_epochs: int = 0,
        number_of_components: int = 1,
        overlap_loss: float = 0.0,
        phenotype_prediction: float = 0.0,
        predictor: float = 0.0,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.batch_size = batch_size
        self.CONV_filters = self.hparams["units_conv"]
        self.dense_activation = dense_activation
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.learn_rate = self.hparams["learning_rate"]
        self.compile = compile_model
        self.delta = huber_delta
        self.entropy_reg_weight = entropy_reg_weight
        self.initialiser_iters = initialiser_iters
        self.kl_warmup = kl_warmup_epochs
        self.loss = loss
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        self.overlap_loss = overlap_loss
        self.phenotype_prediction = phenotype_prediction
        self.predictor = predictor
        self.prior = "standard_normal"

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    @property
    def prior(self):
        """Property holding the model prior, set once the class is instantiated"""

        return self._prior

    def get_prior(self):
        """Sets the Variational Autoencoder prior distribution"""

        if self.prior == "standard_normal":
            init_means = deepof.model_utils.far_away_uniform_initialiser(
                shape=(self.number_of_components, self.ENCODING),
                minval=0,
                maxval=5,
                iters=self.initialiser_iters,
            )
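
            # far_away_uniform_initialiser (from deepof.model_utils) draws the
            # component means so that they start well separated in latent space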

            self.prior = tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                components=[
                    tfd.Independent(
                        tfd.Normal(loc=init_means[k], scale=1,),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            )

        else:  # pragma: no cover
            raise NotImplementedError(
                "Gaussian Mixtures are currently the only supported prior"
            )

    @staticmethod
    def get_hparams(params: Dict) -> Dict:
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "units_conv": 256,
            "units_lstm": 256,
            "units_dense2": 64,
            "dropout_rate": 0.25,
            "encoding": 16,
            "learning_rate": 1e-3,
        }

        for k, v in params.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""
330
331
332
333
334
335
336

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            # kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation=self.dense_activation,
            # kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_D1 = Dense(
            self.DENSE_2,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_D2 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )

        # Predictor layers
        Model_P1 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_P2 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )
        Model_P3 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            )
        )

        # Phenotype classification layers
        Model_PC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_PC1,
        )

    def build(self, input_shape: Tuple):
        """Builds the tf.keras model"""

        # Instantiate prior
        self.get_prior()

        # Get model layers
        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_PC1,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        x = Input(shape=input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
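
        # Latent space: z_cat is a softmax head that assigns each sequence to one of
        # number_of_components mixture components, while z_gauss carries the
        # per-component Gaussian parameters (first self.ENCODING entries per
        # component: means; last self.ENCODING entries: scales, passed through
        # softplus in the Mixture below)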

        # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
        z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
        z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
        z_gauss = Dense(
            deepof.model_utils.tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            ),
            activation=None,
        )(encoder)

        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)

        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
        z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)

        if self.overlap_loss:
            z_gauss = deepof.model_utils.Gaussian_mixture_overlap(
                self.ENCODING, self.number_of_components, loss=self.overlap_loss,
            )(z_gauss)

        z = deepof.model_utils.tfpl.DistributionLambda(
            lambda gauss: tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(probs=gauss[0],),
                components=[
                    tfd.Independent(
                        tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
            convert_to_tensor_fn="sample",
        )([z_cat, z_gauss])
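
        # When warm-up is requested below, the KL / MMD beta weights are annealed
        # linearly from 0 to 1 over the first kl_warmup / mmd_warmup epochs via the
        # LambdaCallbacks that build() returns to the caller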

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = deepof.model_utils.K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        kl_beta, deepof.model_utils.K.min([epoch / self.kl_warmup, 1])
                    )
                )

            z = deepof.model_utils.KLDivergenceLayer(self.prior, weight=kl_beta)(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = deepof.model_utils.K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        mmd_beta, deepof.model_utils.K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = deepof.model_utils.MMDiscrepancyLayer(
                batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
            )(z)

        # Define and instantiate generator
        generator = Model_D1(z)
        generator = Model_B1(generator)
        generator = Model_D2(generator)
        generator = Model_B2(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B3(generator)
        generator = Model_D5(generator)
        generator = Model_B4(generator)
        x_decoded_mean = TimeDistributed(
            Dense(input_shape[2]), name="vaep_reconstruction"
        )(generator)

        model_outs = [x_decoded_mean]
        model_losses = [Huber(delta=self.delta, reduction="sum")]
        model_metrics = {"vaep_reconstruction": ["mae", "mse"]}
        loss_weights = [1.0]

        if self.predictor > 0:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P1(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(input_shape[1])(predictor)
            predictor = Model_P2(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P3(predictor)
            predictor = BatchNormalization()(predictor)
            x_predicted_mean = TimeDistributed(
                Dense(input_shape[2]), name="vaep_prediction"
            )(predictor)

            model_outs.append(x_predicted_mean)
            model_losses.append(Huber(delta=self.delta, reduction="sum"))
            model_metrics["vaep_prediction"] = ["mae", "mse"]
            loss_weights.append(self.predictor)

        if self.phenotype_prediction > 0:
            pheno_pred = Model_PC1(z)
            pheno_pred = Dense(1, activation="sigmoid", name="phenotype_prediction")(
                pheno_pred
            )

            model_outs.append(pheno_pred)
            model_losses.append(BinaryCrossentropy())
            model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
            loss_weights.append(self.phenotype_prediction)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        grouper = Model(x, z_cat, name="Deep_Gaussian_Mixture_clustering")
        # noinspection PyUnboundLocalVariable
        gmvaep = Model(inputs=x, outputs=model_outs, name="SEQ_2_SEQ_GMVAE",)

        # Build generator as a separate entity
        g = Input(shape=(self.ENCODING,))
        _generator = Model_D1(g)
        _generator = Model_B1(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B4(_generator)
        _x_decoded_mean = TimeDistributed(Dense(input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        if self.compile:
            gmvaep.compile(
                loss=model_losses,
                optimizer=Nadam(learning_rate=self.learn_rate, clipvalue=0.5,),
                metrics=model_metrics,
                loss_weights=loss_weights,
            )

        gmvaep.build(input_shape)

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            kl_warmup_callback,
            mmd_warmup_callback,
        )
    @prior.setter
    def prior(self, value):
        self._prior = value
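

# Example usage for SEQ_2_SEQ_GMVAE (illustrative sketch, not from the original
# source; shapes, hyperparameters and random data below are assumptions):
#
#     import numpy as np
#
#     gmvae = SEQ_2_SEQ_GMVAE(number_of_components=10, kl_warmup_epochs=5)
#     encoder, generator, grouper, gmvaep, kl_cb, mmd_cb = gmvae.build((256, 50, 20))
#     x = np.random.normal(size=(256, 50, 20)).astype("float32")
#     gmvaep.fit(x, x, batch_size=256, epochs=10,
#                callbacks=[cb for cb in (kl_cb, mmd_cb) if cb])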

# TODO:
#       - Check KL weight in the overall loss function! Are we scaling the loss components correctly?
#       - Check merge mode in LSTM layers. Maybe we can drastically reduce model size!
#       - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
#       - Investigate posterior collapse (L1 as kernel/activity regulariser does not work)