models.py 27.7 KB
Newer Older
1
# @author lucasmiranda42
2
3
4
5
6
7
8
9
# encoding: utf-8
# module deepof

"""

deep autoencoder models for unsupervised pose detection

"""
10

lucas_miranda's avatar
lucas_miranda committed
11
from typing import Any, Dict, Tuple
lucas_miranda's avatar
lucas_miranda committed
12
13
14

import tensorflow as tf
import tensorflow_probability as tfp
15
from tensorflow.keras import Input, Model, Sequential
16
from tensorflow.keras.activations import softplus
17
from tensorflow.keras.constraints import UnitNorm
18
from tensorflow.keras.initializers import he_uniform, Orthogonal
19
from tensorflow.keras.layers import BatchNormalization, Bidirectional
20
from tensorflow.keras.layers import Dense, Dropout, LSTM
21
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
lucas_miranda's avatar
lucas_miranda committed
22
from tensorflow.keras.losses import Huber
23
from tensorflow.keras.optimizers import Nadam
lucas_miranda's avatar
lucas_miranda committed
24

25
import deepof.model_utils
26

27
tfb = tfp.bijectors
28
29
tfd = tfp.distributions
tfpl = tfp.layers
30
31


lucas_miranda's avatar
lucas_miranda committed
32
# noinspection PyDefaultArgument
33
class SEQ_2_SEQ_AE:
34
    """  Simple sequence to sequence autoencoder implemented with tf.keras """
lucas_miranda's avatar
lucas_miranda committed
35

36
    def __init__(
37
38
39
        self,
        architecture_hparams: Dict = {},
        huber_delta: float = 1.0,
40
    ):
lucas_miranda's avatar
lucas_miranda committed
41
42
43
44
45
46
47
48
49
        self.hparams = self.get_hparams(architecture_hparams)
        self.CONV_filters = self.hparams["units_conv"]
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.learn_rate = self.hparams["learning_rate"]
50
        self.delta = huber_delta
51

lucas_miranda's avatar
lucas_miranda committed
52
53
54
55
56
57
58
59
60
61
    @staticmethod
    def get_hparams(hparams):
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "units_conv": 256,
            "units_lstm": 256,
            "units_dense2": 64,
            "dropout_rate": 0.25,
            "encoding": 16,
62
            "learning_rate": 1e-3,
lucas_miranda's avatar
lucas_miranda committed
63
64
65
66
67
68
69
70
71
        }

        for k, v in hparams.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""
lucas_miranda's avatar
lucas_miranda committed
72

73
74
75
76
77
78
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
79
            activation="elu",
80
            kernel_initializer=he_uniform(),
81
        )
82
        Model_E1 = Bidirectional(
83
            LSTM(
84
85
                self.LSTM_units_1,
                activation="tanh",
86
                recurrent_activation="sigmoid",
87
                return_sequences=True,
88
                kernel_constraint=UnitNorm(axis=0),
89
90
            )
        )
91
        Model_E2 = Bidirectional(
92
            LSTM(
93
94
                self.LSTM_units_2,
                activation="tanh",
95
                recurrent_activation="sigmoid",
96
                return_sequences=False,
97
                kernel_constraint=UnitNorm(axis=0),
98
99
            )
        )
100
        Model_E3 = Dense(
101
            self.DENSE_1,
102
            activation="elu",
103
104
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
105
106
        )
        Model_E4 = Dense(
107
            self.DENSE_2,
108
            activation="elu",
109
110
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
111
        )
112
113
        Model_E5 = Dense(
            self.ENCODING,
114
            activation="elu",
115
            kernel_constraint=UnitNorm(axis=1),
116
117
118
            activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
                2, weightage=1.0
            ),
119
            kernel_initializer=Orthogonal(),
120
121
122
        )

        # Decoder layers
123
        Model_D0 = deepof.model_utils.DenseTranspose(
124
125
126
            Model_E5,
            activation="elu",
            output_dim=self.ENCODING,
127
128
        )
        Model_D1 = deepof.model_utils.DenseTranspose(
129
130
131
            Model_E4,
            activation="elu",
            output_dim=self.DENSE_2,
132
133
        )
        Model_D2 = deepof.model_utils.DenseTranspose(
134
135
136
            Model_E3,
            activation="elu",
            output_dim=self.DENSE_1,
137
        )
lucas_miranda's avatar
lucas_miranda committed
138
        Model_D3 = RepeatVector(input_shape[1])
139
        Model_D4 = Bidirectional(
140
            LSTM(
141
142
                self.LSTM_units_1,
                activation="tanh",
143
                recurrent_activation="sigmoid",
144
                return_sequences=True,
lucas_miranda's avatar
lucas_miranda committed
145
                # kernel_constraint=UnitNorm(axis=1),
146
147
            )
        )
148
        Model_D5 = Bidirectional(
149
            LSTM(
150
151
                self.LSTM_units_1,
                activation="sigmoid",
152
                recurrent_activation="sigmoid",
153
                return_sequences=True,
lucas_miranda's avatar
lucas_miranda committed
154
                # kernel_constraint=UnitNorm(axis=1),
155
156
157
            )
        )

lucas_miranda's avatar
lucas_miranda committed
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        )

173
    def build(
174
175
        self,
        input_shape: tuple,
176
    ) -> Tuple[Any, Any, Any]:
lucas_miranda's avatar
lucas_miranda committed
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
        """Builds the tf.keras model"""

        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        ) = self.get_layers(input_shape)

194
        # Define and instantiate encoder
lucas_miranda's avatar
lucas_miranda committed
195
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
lucas_miranda's avatar
lucas_miranda committed
196
        encoder.add(Input(shape=input_shape[1:]))
197
        encoder.add(Model_E0)
198
        encoder.add(BatchNormalization())
199
        encoder.add(Model_E1)
200
        encoder.add(BatchNormalization())
201
        encoder.add(Model_E2)
202
        encoder.add(BatchNormalization())
203
        encoder.add(Model_E3)
204
        encoder.add(BatchNormalization())
205
        encoder.add(Dropout(self.DROPOUT_RATE))
206
        encoder.add(Model_E4)
207
        encoder.add(BatchNormalization())
208
209
        encoder.add(Model_E5)

210
        # Define and instantiate decoder
lucas_miranda's avatar
lucas_miranda committed
211
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
212
        decoder.add(Model_D0)
213
        decoder.add(BatchNormalization())
214
        decoder.add(Model_D1)
215
        decoder.add(BatchNormalization())
216
        decoder.add(Model_D2)
217
        decoder.add(BatchNormalization())
218
        decoder.add(Model_D3)
219
        decoder.add(Model_D4)
220
        decoder.add(BatchNormalization())
221
        decoder.add(Model_D5)
lucas_miranda's avatar
lucas_miranda committed
222
        decoder.add(TimeDistributed(Dense(input_shape[2])))
223

lucas_miranda's avatar
lucas_miranda committed
224
        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")
225
226

        model.compile(
lucas_miranda's avatar
lucas_miranda committed
227
            loss=Huber(delta=self.delta),
228
229
230
231
            optimizer=Nadam(
                lr=self.learn_rate,
                clipvalue=0.5,
            ),
232
233
234
            metrics=["mae"],
        )

235
236
        model.build(input_shape)

lucas_miranda's avatar
lucas_miranda committed
237
        return encoder, decoder, model
238
239


lucas_miranda's avatar
lucas_miranda committed
240
# noinspection PyDefaultArgument
241
class SEQ_2_SEQ_GMVAE:
242
    """  Gaussian Mixture Variational Autoencoder for pose motif elucidation.  """
lucas_miranda's avatar
lucas_miranda committed
243

244
    def __init__(
245
246
247
248
249
250
251
252
253
254
255
256
        self,
        architecture_hparams: dict = {},
        batch_size: int = 256,
        compile_model: bool = True,
        encoding: int = 6,
        kl_warmup_epochs: int = 20,
        loss: str = "ELBO",
        mmd_warmup_epochs: int = 20,
        montecarlo_kl: int = 1,
        neuron_control: bool = False,
        number_of_components: int = 1,
        overlap_loss: float = 0.0,
257
        next_sequence_prediction: float = 0.0,
258
        phenotype_prediction: float = 0.0,
259
260
        rule_based_prediction: float = 0.0,
        rule_based_features: int = 6,
261
262
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
263
    ):
lucas_miranda's avatar
lucas_miranda committed
264
        self.hparams = self.get_hparams(architecture_hparams)
265
        self.batch_size = batch_size
266
        self.bidirectional_merge = self.hparams["bidirectional_merge"]
lucas_miranda's avatar
lucas_miranda committed
267
268
269
270
        self.CONV_filters = self.hparams["units_conv"]
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
271
        self.ENCODING = encoding
lucas_miranda's avatar
lucas_miranda committed
272
273
274
275
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.clipvalue = self.hparams["clipvalue"]
        self.dense_activation = self.hparams["dense_activation"]
276
        self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
lucas_miranda's avatar
lucas_miranda committed
277
        self.learn_rate = self.hparams["learning_rate"]
lucas_miranda's avatar
lucas_miranda committed
278
        self.lstm_unroll = True
279
        self.compile = compile_model
280
        self.kl_warmup = kl_warmup_epochs
281
        self.loss = loss
282
        self.mc_kl = montecarlo_kl
283
        self.mmd_warmup = mmd_warmup_epochs
284
        self.neuron_control = neuron_control
285
        self.number_of_components = number_of_components
286
        self.optimizer = Nadam(lr=self.learn_rate, clipvalue=self.clipvalue)
287
        self.overlap_loss = overlap_loss
288
        self.next_sequence_prediction = next_sequence_prediction
lucas_miranda's avatar
lucas_miranda committed
289
        self.phenotype_prediction = phenotype_prediction
290
291
        self.rule_based_prediction = rule_based_prediction
        self.rule_based_features = rule_based_features
292
        self.prior = "standard_normal"
293
294
        self.reg_cat_clusters = reg_cat_clusters
        self.reg_cluster_variance = reg_cluster_variance
295

lucas_miranda's avatar
lucas_miranda committed
296
        assert (
297
            "ELBO" in self.loss or "MMD" in self.loss
lucas_miranda's avatar
lucas_miranda committed
298
299
300
301
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    @property
    def prior(self):
302
303
304
        """Property to set the value of the prior
        once the class is instanciated"""

lucas_miranda's avatar
lucas_miranda committed
305
306
307
308
309
        return self._prior

    def get_prior(self):
        """Sets the Variational Autoencoder prior distribution"""

310
        if self.prior == "standard_normal":
311
312
313

            self.prior = tfd.MixtureSameFamily(
                mixture_distribution=tfd.categorical.Categorical(
314
315
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
316
317
                components_distribution=tfd.MultivariateNormalDiag(
                    loc=tf.Variable(
318
                        Orthogonal()(
319
                            [self.number_of_components, self.ENCODING],
320
321
                        ),
                        name="prior_means",
322
323
324
325
                    ),
                    scale_diag=tfp.util.TransformedVariable(
                        tf.ones([self.number_of_components, self.ENCODING]),
                        tfb.Softplus(),
326
                        name="prior_scales",
327
328
                    ),
                ),
329
            )
330

331
        else:  # pragma: no cover
lucas_miranda's avatar
lucas_miranda committed
332
333
334
335
336
337
338
339
340
            raise NotImplementedError(
                "Gaussian Mixtures are currently the only supported prior"
            )

    @staticmethod
    def get_hparams(params: Dict) -> Dict:
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
341
            "bidirectional_merge": "concat",
lucas_miranda's avatar
lucas_miranda committed
342
343
            "clipvalue": 1.0,
            "dense_activation": "relu",
344
            "dense_layers_per_branch": 3,
345
346
            "dropout_rate": 0.05,
            "learning_rate": 1e-3,
347
            "units_conv": 64,
348
            "units_dense2": 32,
349
            "units_lstm": 128,
lucas_miranda's avatar
lucas_miranda committed
350
351
352
353
354
355
356
357
358
        }

        for k, v in params.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""
359
360
361
362
363
364

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
365
            padding="same",
lucas_miranda's avatar
lucas_miranda committed
366
            activation=self.dense_activation,
367
            kernel_initializer=he_uniform(),
lucas_miranda's avatar
lucas_miranda committed
368
            use_bias=True,
369
370
371
372
373
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
374
                recurrent_activation="sigmoid",
375
                return_sequences=True,
lucas_miranda's avatar
lucas_miranda committed
376
                unroll=self.lstm_unroll,
lucas_miranda's avatar
lucas_miranda committed
377
                # kernel_constraint=UnitNorm(axis=0),
lucas_miranda's avatar
lucas_miranda committed
378
                use_bias=True,
379
380
            ),
            merge_mode=self.bidirectional_merge,
381
382
383
384
385
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
386
                recurrent_activation="sigmoid",
387
                return_sequences=False,
lucas_miranda's avatar
lucas_miranda committed
388
                unroll=self.lstm_unroll,
lucas_miranda's avatar
lucas_miranda committed
389
                # kernel_constraint=UnitNorm(axis=0),
lucas_miranda's avatar
lucas_miranda committed
390
                use_bias=True,
391
392
            ),
            merge_mode=self.bidirectional_merge,
393
394
395
        )
        Model_E3 = Dense(
            self.DENSE_1,
lucas_miranda's avatar
lucas_miranda committed
396
397
            activation=self.dense_activation,
            # kernel_constraint=UnitNorm(axis=0),
398
            kernel_initializer=he_uniform(),
lucas_miranda's avatar
lucas_miranda committed
399
            use_bias=True,
400
        )
401

402
        seq_E = [
403
404
405
406
407
408
409
410
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                # kernel_constraint=UnitNorm(axis=0),
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
411
        ]
412
413
414
415
        Model_E4 = []
        for l in seq_E:
            Model_E4.append(l)
            Model_E4.append(BatchNormalization())
416
417
418
419
420

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
421
422
        Model_B4 = BatchNormalization()

423
        seq_D = [
424
425
426
427
428
429
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
430
            for _ in range(self.dense_layers_per_branch)
431
        ]
432
433
434
435
436
        Model_D1 = []
        for l in seq_D:
            Model_D1.append(l)
            Model_D1.append(BatchNormalization())

437
        Model_D2 = Dense(
438
            self.DENSE_1,
lucas_miranda's avatar
lucas_miranda committed
439
            activation=self.dense_activation,
440
            kernel_initializer=he_uniform(),
lucas_miranda's avatar
lucas_miranda committed
441
            use_bias=True,
442
        )
lucas_miranda's avatar
lucas_miranda committed
443
        Model_D3 = RepeatVector(input_shape[1])
444
445
        Model_D4 = Bidirectional(
            LSTM(
446
                self.LSTM_units_2,
447
                activation="tanh",
448
                recurrent_activation="sigmoid",
449
                return_sequences=True,
lucas_miranda's avatar
lucas_miranda committed
450
                unroll=self.lstm_unroll,
lucas_miranda's avatar
lucas_miranda committed
451
                # kernel_constraint=UnitNorm(axis=1),
lucas_miranda's avatar
lucas_miranda committed
452
                use_bias=True,
453
454
            ),
            merge_mode=self.bidirectional_merge,
455
456
457
458
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
459
                activation="tanh",
460
                recurrent_activation="sigmoid",
461
                return_sequences=True,
lucas_miranda's avatar
lucas_miranda committed
462
                unroll=self.lstm_unroll,
lucas_miranda's avatar
lucas_miranda committed
463
                # kernel_constraint=UnitNorm(axis=1),
lucas_miranda's avatar
lucas_miranda committed
464
                use_bias=True,
465
466
            ),
            merge_mode=self.bidirectional_merge,
467
        )
468
469
470
471
472
473
474
475
476
        Model_D6 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="same",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
lucas_miranda's avatar
lucas_miranda committed
477
478

        # Predictor layers
lucas_miranda's avatar
lucas_miranda committed
479
480
        Model_P1 = Dense(
            self.DENSE_1,
lucas_miranda's avatar
lucas_miranda committed
481
            activation=self.dense_activation,
lucas_miranda's avatar
lucas_miranda committed
482
            kernel_initializer=he_uniform(),
lucas_miranda's avatar
lucas_miranda committed
483
            use_bias=True,
lucas_miranda's avatar
lucas_miranda committed
484
485
486
487
488
489
490
        )
        Model_P2 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
lucas_miranda's avatar
lucas_miranda committed
491
                unroll=self.lstm_unroll,
lucas_miranda's avatar
lucas_miranda committed
492
                # kernel_constraint=UnitNorm(axis=1),
lucas_miranda's avatar
lucas_miranda committed
493
                use_bias=True,
494
495
            ),
            merge_mode=self.bidirectional_merge,
lucas_miranda's avatar
lucas_miranda committed
496
497
498
499
500
501
502
        )
        Model_P3 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
lucas_miranda's avatar
lucas_miranda committed
503
                unroll=self.lstm_unroll,
lucas_miranda's avatar
lucas_miranda committed
504
                # kernel_constraint=UnitNorm(axis=1),
lucas_miranda's avatar
lucas_miranda committed
505
                use_bias=True,
506
507
            ),
            merge_mode=self.bidirectional_merge,
lucas_miranda's avatar
lucas_miranda committed
508
        )
509

510
        # Phenotype classification layer
511
        Model_PC1 = Dense(
lucas_miranda's avatar
lucas_miranda committed
512
513
514
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
515
        )
lucas_miranda's avatar
lucas_miranda committed
516

517
518
519
520
521
522
523
        # Rule based trait classification layer
        Model_RC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

lucas_miranda's avatar
lucas_miranda committed
524
525
526
527
528
529
530
531
532
        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
533
            Model_B4,
lucas_miranda's avatar
lucas_miranda committed
534
535
536
537
538
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
539
            Model_D6,
lucas_miranda's avatar
lucas_miranda committed
540
541
542
            Model_P1,
            Model_P2,
            Model_P3,
lucas_miranda's avatar
lucas_miranda committed
543
            Model_PC1,
544
            Model_RC1,
lucas_miranda's avatar
lucas_miranda committed
545
546
547
        )

    def build(self, input_shape: Tuple):
548
        """Builds the tf.keras model"""
lucas_miranda's avatar
lucas_miranda committed
549
550
551
552
553
554
555
556
557
558
559
560
561
562

        # Instanciate prior
        self.get_prior()

        # Get model layers
        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
563
            Model_B4,
lucas_miranda's avatar
lucas_miranda committed
564
565
566
567
568
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
569
            Model_D6,
lucas_miranda's avatar
lucas_miranda committed
570
571
572
            Model_P1,
            Model_P2,
            Model_P3,
lucas_miranda's avatar
lucas_miranda committed
573
            Model_PC1,
574
            Model_RC1,
lucas_miranda's avatar
lucas_miranda committed
575
576
        ) = self.get_layers(input_shape)

577
        # Define and instantiate encoder
lucas_miranda's avatar
lucas_miranda committed
578
        x = Input(shape=input_shape[1:])
579
580
581
582
583
584
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
585
586
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
587
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
588
        encoder = Sequential(Model_E4)(encoder)
589

590
        # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
591
592
        z_cat = Dense(
            self.number_of_components,
lucas_miranda's avatar
lucas_miranda committed
593
            name="cluster_assignment",
594
            activation="softmax",
595
            activity_regularizer=(
596
597
598
599
                tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
                if self.reg_cat_clusters
                else None
            ),
600
        )(encoder)
601

602
        z_gauss_mean = Dense(
lucas_miranda's avatar
lucas_miranda committed
603
            tfpl.IndependentNormal.params_size(
604
                self.ENCODING * self.number_of_components
605
606
            )
            // 2,
607
            name="cluster_means",
608
            activation=None,
609
610
            kernel_initializer=Orthogonal(),  # An alternative is a constant initializer with a matrix of values
            # computed from the labels, we could also initialize the prior this way, and update it every N epochs
611
612
613
614
615
616
617
        )(encoder)

        z_gauss_var = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            )
            // 2,
618
            name="cluster_variances",
619
            activation=None,
620
621
622
623
624
625
            activity_regularizer=(
                tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
            ),
        )(encoder)

        z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
626

627
        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
628

lucas_miranda's avatar
lucas_miranda committed
629
        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
630
631
        if self.neuron_control:
            z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)
lucas_miranda's avatar
lucas_miranda committed
632

633
        if self.overlap_loss:
634
            z_gauss = deepof.model_utils.Cluster_overlap(
635
636
                self.ENCODING,
                self.number_of_components,
637
                loss=self.overlap_loss,
638
            )(z_gauss)
639

lucas_miranda's avatar
lucas_miranda committed
640
        z = tfpl.DistributionLambda(
641
            make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
642
643
644
                cat=tfd.categorical.Categorical(
                    probs=gauss[0],
                ),
645
                components=[
lucas_miranda's avatar
lucas_miranda committed
646
647
                    tfd.Independent(
                        tfd.Normal(
648
                            loc=gauss[1][..., : self.ENCODING, k],
649
650
651
                            scale=1e-3
                            + softplus(gauss[1][..., self.ENCODING :, k])
                            + 1e-5,
652
653
654
655
656
657
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
658
            convert_to_tensor_fn="sample",
659
            name="encoding_distribution",
660
        )([z_cat, z_gauss])
661

662
        posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")
663

664
        # Define and control custom loss functions
665
        if "ELBO" in self.loss:
666
            kl_warm_up_iters = tf.cast(
667
                self.kl_warmup * (input_shape[0] // self.batch_size + 1),
668
                tf.int64,
669
670
            )

671
            # noinspection PyCallingNonCallable
672
            z = deepof.model_utils.KLDivergenceLayer(
673
                distribution_b=self.prior,
674
                test_points_fn=lambda q: q.sample(self.mc_kl),
675
                test_points_reduce_axis=0,
676
                iters=self.optimizer.iterations,
677
                warm_up_iters=kl_warm_up_iters,
678
            )(z)
679
680

        if "MMD" in self.loss:
681
            mmd_warm_up_iters = tf.cast(
682
683
684
                self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
                tf.int64,
            )
685

686
            z = deepof.model_utils.MMDiscrepancyLayer(
687
688
689
                batch_size=self.batch_size,
                prior=self.prior,
                iters=self.optimizer.iterations,
690
                warm_up_iters=mmd_warm_up_iters,
691
            )(z)
692

693
        # Dummy layer with no parameters, to retrieve the previous tensor
694
        z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)
695

696
        # Define and instantiate generator
lucas_miranda's avatar
lucas_miranda committed
697
        g = Input(shape=self.ENCODING)
698
699
        generator = Sequential(Model_D1)(g)
        generator = Model_D2(generator)
700
        generator = Model_B1(generator)
701
        generator = Model_D3(generator)
702
        generator = Model_D4(generator)
703
        generator = Model_B2(generator)
704
        generator = Model_D5(generator)
705
        generator = Model_B3(generator)
706
707
        generator = Model_D6(generator)
        generator = Model_B4(generator)
708
709
710
711
712
713
714
715
716
        x_decoded_mean = Dense(
            tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
        )(generator)
        x_decoded_var = tf.keras.activations.softplus(
            Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
        )
        x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
        x_decoded = tf.keras.layers.concatenate(
            [x_decoded_mean, x_decoded_var], axis=-1
717
        )
718
        x_decoded_mean = tfpl.IndependentNormal(
719
720
721
            event_shape=input_shape[2:],
            convert_to_tensor_fn=tfp.distributions.Distribution.mean,
            name="vae_reconstruction",
722
        )(x_decoded)
723

lucas_miranda's avatar
lucas_miranda committed
724
        # define individual branches as models
725
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
lucas_miranda's avatar
lucas_miranda committed
726
727
        generator = Model(g, x_decoded_mean, name="vae_reconstruction")

728
729
730
        def log_loss(x_true, p_x_q_given_z):
            """Computes the negative log likelihood of the data given
            the output distribution"""
731
            return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))
732

733
        model_outs = [generator(encoder.outputs)]
734
735
        model_losses = [log_loss]
        model_metrics = {"vae_reconstruction": ["mae", "mse"]}
736
        loss_weights = [1.0]
lucas_miranda's avatar
lucas_miranda committed
737

738
        if self.next_sequence_prediction > 0:
739
            # Define and instantiate predictor
740
741
742
743
744
745
746
            predictor = Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P1(predictor)
747
            predictor = BatchNormalization()(predictor)
lucas_miranda's avatar
lucas_miranda committed
748
            predictor = RepeatVector(input_shape[1])(predictor)
lucas_miranda's avatar
lucas_miranda committed
749
            predictor = Model_P2(predictor)
750
            predictor = BatchNormalization()(predictor)
lucas_miranda's avatar
lucas_miranda committed
751
            predictor = Model_P3(predictor)
752
            predictor = BatchNormalization()(predictor)
753
754
755
756
757
758
759
760
761
762
763
764
765
            x_predicted_mean = Dense(
                tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
            )(predictor)
            x_predicted_var = tf.keras.activations.softplus(
                Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
                    predictor
                )
            )
            x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
                x_predicted_var
            )
            x_decoded = tf.keras.layers.concatenate(
                [x_predicted_mean, x_predicted_var], axis=-1
766
767
768
769
770
            )
            x_predicted_mean = tfpl.IndependentNormal(
                event_shape=input_shape[2:],
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="vae_prediction",
771
            )(x_decoded)
772

lucas_miranda's avatar
lucas_miranda committed
773
            model_outs.append(x_predicted_mean)
774
            model_losses.append(log_loss)
775
            model_metrics["vae_prediction"] = ["mae", "mse"]
776
            loss_weights.append(self.next_sequence_prediction)
lucas_miranda's avatar
lucas_miranda committed
777
778
779

        if self.phenotype_prediction > 0:
            pheno_pred = Model_PC1(z)
780
781
782
783
784
785
            pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
            pheno_pred = tfpl.IndependentBernoulli(
                event_shape=1,
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="phenotype_prediction",
            )(pheno_pred)
lucas_miranda's avatar
lucas_miranda committed
786
787

            model_outs.append(pheno_pred)
788
            model_losses.append(log_loss)
789
            model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
lucas_miranda's avatar
lucas_miranda committed
790
791
            loss_weights.append(self.phenotype_prediction)

792
793
        if self.rule_based_prediction > 0:
            rule_pred = Model_RC1(z)
794

795
796
797
798
799
800
801
802
803
804
805
            rule_pred = Dense(
                tfpl.IndependentBernoulli.params_size(self.rule_based_features)
            )(rule_pred)
            rule_pred = tfpl.IndependentBernoulli(
                event_shape=self.rule_based_features,
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="rule_based_prediction",
            )(rule_pred)

            model_outs.append(rule_pred)
            model_losses.append(log_loss)
806
807
808
809
            model_metrics["rule_based_prediction"] = [
                "mae",
                "mse",
            ]
810
811
            loss_weights.append(self.rule_based_prediction)

lucas_miranda's avatar
lucas_miranda committed
812
        # define grouper and end-to-end autoencoder model
813
        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
814
        gmvaep = Model(
815
            inputs=encoder.inputs,
816
817
818
            outputs=model_outs,
            name="SEQ_2_SEQ_GMVAE",
        )
819

820
821
822
        if self.compile:
            gmvaep.compile(
                loss=model_losses,
823
                optimizer=self.optimizer,
824
                metrics=model_metrics,
825
826
                loss_weights=loss_weights,
            )
827

828
829
        gmvaep.build(input_shape)

830
        return (
831
            encoder,
832
833
834
            generator,
            grouper,
            gmvaep,
835
836
            self.prior,
            posterior,
837
        )
lucas_miranda's avatar
lucas_miranda committed
838

lucas_miranda's avatar
lucas_miranda committed
839
840
841
842
    @prior.setter
    def prior(self, value):
        self._prior = value

843

844
# TODO:
lucas_miranda's avatar
lucas_miranda committed
845
#       - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
846
#       - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
lucas_miranda's avatar
lucas_miranda committed
847
#       - Explore expanding the event dims of the final reconstruction layer
848
#       - Think about gradient penalty to avoid mode collapse (as in WGAN-GP)
849
#       - Think about using spectral normalization