# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

deep autoencoder models for unsupervised pose detection

"""

from typing import Any, Dict, Tuple
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Nadam
import deepof.model_utils
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions
tfpl = tfp.layers


# noinspection PyDefaultArgument
class SEQ_2_SEQ_AE:
    """Simple sequence to sequence autoencoder implemented with tf.keras"""

    def __init__(
        self,
        architecture_hparams: Dict = {},
        huber_delta: float = 1.0,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.CONV_filters = self.hparams["units_conv"]
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = self.hparams["encoding"]
        self.learn_rate = self.hparams["learning_rate"]
        self.delta = huber_delta

    @staticmethod
    def get_hparams(hparams: Dict) -> Dict:
        """Sets the default parameters for the model. Overridable with a dictionary"""

        defaults = {
            "units_conv": 256,
            "units_lstm": 256,
            "units_dense2": 64,
            "dropout_rate": 0.25,
            "encoding": 16,
            "learning_rate": 1e-3,
        }

        for k, v in hparams.items():
            defaults[k] = v

        return defaults
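
    # Overrides are merged key-by-key into the defaults, e.g. (illustrative):
    #
    #   ae = SEQ_2_SEQ_AE(architecture_hparams={"units_lstm": 128})
    #   ae.LSTM_units_1  # -> 128; unspecified keys keep their default values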

    def get_layers(self, input_shape):
        """Instantiate all layers in the model"""

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="elu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="elu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="elu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
                2, weightage=1.0
            ),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_D0 = deepof.model_utils.DenseTranspose(
            Model_E5,
            activation="elu",
            output_dim=self.ENCODING,
        )
        Model_D1 = deepof.model_utils.DenseTranspose(
            Model_E4,
            activation="elu",
            output_dim=self.DENSE_2,
        )
        Model_D2 = deepof.model_utils.DenseTranspose(
            Model_E3,
            activation="elu",
            output_dim=self.DENSE_1,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                # kernel_constraint=UnitNorm(axis=1),
            )
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        )

    def build(
        self,
        input_shape: tuple,
    ) -> Tuple[Any, Any, Any]:
        """Builds the tf.keras model"""

        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_E5,
            Model_D0,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(delta=self.delta),
            optimizer=Nadam(
                learning_rate=self.learn_rate,
                clipvalue=0.5,
            ),
            metrics=["mae"],
        )

        model.build(input_shape)

        return encoder, decoder, model
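
# Example usage (illustrative sketch, not part of the deepof API; assumes a
# preprocessed array `coords` of shape (samples, timesteps, features)):
#
#   encoder, decoder, ae = SEQ_2_SEQ_AE().build(coords.shape)
#   ae.fit(coords, coords, batch_size=256, epochs=25)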


# noinspection PyDefaultArgument
class SEQ_2_SEQ_GMVAE:
    """Gaussian Mixture Variational Autoencoder for pose motif elucidation."""

    def __init__(
        self,
        architecture_hparams: dict = {},
        batch_size: int = 256,
        compile_model: bool = True,
        encoding: int = 16,
        entropy_reg_weight: float = 0.0,
        initialiser_iters: int = 1,
        kl_warmup_epochs: int = 20,
        loss: str = "ELBO",
        mmd_warmup_epochs: int = 20,
        montecarlo_kl: int = 1,
        neuron_control: bool = False,
        number_of_components: int = 1,
        overlap_loss: float = 0.0,
        phenotype_prediction: float = 0.0,
        predictor: float = 0.0,
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.batch_size = batch_size
        self.bidirectional_merge = self.hparams["bidirectional_merge"]
        self.CONV_filters = self.hparams["units_conv"]
        self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = encoding
        self.LSTM_units_1 = self.hparams["units_lstm"]
        self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
        self.clipvalue = self.hparams["clipvalue"]
        self.dense_activation = self.hparams["dense_activation"]
        self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
        self.learn_rate = self.hparams["learning_rate"]
        self.lstm_unroll = True
        self.compile = compile_model
        self.entropy_reg_weight = entropy_reg_weight
        self.initialiser_iters = initialiser_iters
        self.kl_warmup = kl_warmup_epochs
        self.loss = loss
        self.mc_kl = montecarlo_kl
        self.mmd_warmup = mmd_warmup_epochs
        self.neuron_control = neuron_control
        self.number_of_components = number_of_components
        self.overlap_loss = overlap_loss
        self.phenotype_prediction = phenotype_prediction
        self.predictor = predictor
        self.prior = "standard_normal"
        self.reg_cat_clusters = reg_cat_clusters
        self.reg_cluster_variance = reg_cluster_variance

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    @property
    def prior(self):
        """Property to set the value of the prior
        once the class is instantiated"""

        return self._prior

    def get_prior(self):
        """Sets the Variational Autoencoder prior distribution"""

        if self.prior == "standard_normal":
            # init_means = deepof.model_utils.far_away_uniform_initialiser(
            #     shape=(self.number_of_components, self.ENCODING),
            #     minval=0,
            #     maxval=5,
            #     iters=self.initialiser_iters,
            # )

            self.prior = tfd.MixtureSameFamily(
                mixture_distribution=tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                components_distribution=tfd.MultivariateNormalDiag(
                    loc=tf.Variable(
                        tf.random.normal(
                            [self.number_of_components, self.ENCODING],
                            name="prior_means",
                        )
                    ),
                    scale_diag=tfp.util.TransformedVariable(
                        tf.ones([self.number_of_components, self.ENCODING]),
                        tfb.Softplus(),
                        name="prior_scales",
                    ),
                ),
            )

        else:  # pragma: no cover
            raise NotImplementedError(
                "Gaussian Mixtures are currently the only supported prior"
            )
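
    # The prior set above is a trainable Gaussian mixture over the latent space.
    # A quick standalone check (illustrative only, not part of the deepof API):
    #
    #   gmvae = SEQ_2_SEQ_GMVAE(number_of_components=5)
    #   gmvae.get_prior()
    #   samples = gmvae.prior.sample(10)       # shape: (10, gmvae.ENCODING)
    #   log_p = gmvae.prior.log_prob(samples)  # shape: (10,)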

    @staticmethod
    def get_hparams(params: Dict) -> Dict:
        """Sets the default parameters for the model. Overridable with a dictionary"""

        defaults = {
            "bidirectional_merge": "ave",
            "clipvalue": 1.0,
            "dense_activation": "relu",
            "dense_layers_per_branch": 1,
            "dropout_rate": 1e-3,
            "learning_rate": 1e-3,
            "units_conv": 160,
            "units_dense2": 120,
            "units_lstm": 300,
        }

        for k, v in params.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instantiate all layers in the model"""

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            # kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        Model_E4 = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                # kernel_constraint=UnitNorm(axis=0),
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_D1 = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]
        Model_D2 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )

        # Predictor layers
        Model_P1 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_P2 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_P3 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.lstm_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )

        # Phenotype classification layers
        Model_PC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_PC1,
        )

    def build(self, input_shape: Tuple):
        """Builds the tf.keras model"""

        # Instantiate prior
        self.get_prior()

        # Get model layers
        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_B1,
            Model_B2,
            Model_B3,
            Model_B4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_PC1,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        x = Input(shape=input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        # encoder = Sequential(Model_E4)(encoder)
        # encoder = BatchNormalization()(encoder)

        # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
        z_cat = Dense(
            self.number_of_components,
            activation="softmax",
            kernel_regularizer=(
                tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
                if self.reg_cat_clusters
                else None
            ),
        )(encoder)

        if self.entropy_reg_weight > 0:
            z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(
                z_cat
            )

        z_gauss_mean = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            )
            // 2,
            activation=None,
            kernel_initializer=Orthogonal(),
        )(encoder)

        z_gauss_var = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            )
            // 2,
            activation=None,
            activity_regularizer=(
                tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
            ),
        )(encoder)

        z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)

        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)

        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
        if self.neuron_control:
            z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)

        if self.overlap_loss:
            z_gauss = deepof.model_utils.Gaussian_mixture_overlap(
                self.ENCODING,
                self.number_of_components,
                loss=self.overlap_loss,
            )(z_gauss)

        z = tfpl.DistributionLambda(
            lambda gauss: tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(
                    probs=gauss[0],
                ),
                components=[
                    tfd.Independent(
                        tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
            convert_to_tensor_fn="sample",
        )([z_cat, z_gauss])
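
        # z is distribution-valued: a mixture with one diagonal-Gaussian component
        # per cluster, weighted by the categorical probabilities in z_cat. With
        # convert_to_tensor_fn="sample", downstream layers see a (batch, ENCODING)
        # sample drawn from this posterior.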

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = deepof.model_utils.K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        kl_beta, deepof.model_utils.K.min([epoch / self.kl_warmup, 1])
                    )
                )

            # noinspection PyCallingNonCallable
            z = deepof.model_utils.KLDivergenceLayer(
                self.prior,
                test_points_fn=lambda q: q.sample(self.mc_kl),
                test_points_reduce_axis=0,
                weight=kl_beta,
            )(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = deepof.model_utils.K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
                        mmd_beta, deepof.model_utils.K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = deepof.model_utils.MMDiscrepancyLayer(
                batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
            )(z)

        # Define and instantiate generator
        g = Input(shape=(self.ENCODING,))
        # generator = Sequential(Model_D1)(g)
        # generator = Model_B1(generator)
        generator = Model_D2(g)
        generator = Model_B2(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B3(generator)
        generator = Model_D5(generator)
        generator = Model_B4(generator)
        generator = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
            generator
        )
        x_decoded_mean = tfpl.IndependentNormal(
            event_shape=input_shape[2:],
            convert_to_tensor_fn=tfp.distributions.Distribution.mean,
            name="vae_reconstruction",
        )(generator)

        # Define individual branches as models
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        generator = Model(g, x_decoded_mean, name="vae_reconstruction")

        def log_loss(x_true, p_x_q_given_z):
            """Computes the negative log likelihood of the data given
            the output distribution"""
            return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))
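
        # Note that log_loss sums (rather than averages) the negative
        # log-likelihood over the batch; Keras then scales each output's loss
        # by the matching entry in loss_weights below.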

        model_outs = [generator(encoder.outputs)]
        model_losses = [log_loss]
        model_metrics = {"vae_reconstruction": ["mae", "mse"]}
        loss_weights = [1.0]

        if self.predictor > 0:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P1(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(input_shape[1])(predictor)
            predictor = Model_P2(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P3(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
                predictor
            )
            x_predicted_mean = tfpl.IndependentNormal(
                event_shape=input_shape[2:],
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="vae_prediction",
            )(predictor)

            model_outs.append(x_predicted_mean)
            model_losses.append(log_loss)
            model_metrics["vae_prediction"] = ["mae", "mse"]
            loss_weights.append(self.predictor)

        if self.phenotype_prediction > 0:
            pheno_pred = Model_PC1(z)
            pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
            pheno_pred = tfpl.IndependentBernoulli(
                event_shape=1,
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="phenotype_prediction",
            )(pheno_pred)

            model_outs.append(pheno_pred)
            model_losses.append(log_loss)
            model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
            loss_weights.append(self.phenotype_prediction)

        # Define grouper and end-to-end autoencoder model
        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
        gmvaep = Model(
            inputs=encoder.inputs,
            outputs=model_outs,
            name="SEQ_2_SEQ_GMVAE",
        )

        if self.compile:
            gmvaep.compile(
                loss=model_losses,
                optimizer=Nadam(
                    learning_rate=self.learn_rate,
                    clipvalue=self.clipvalue,
                ),
                metrics=model_metrics,
                loss_weights=loss_weights,
            )

        gmvaep.build(input_shape)

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            kl_warmup_callback,
            mmd_warmup_callback,
        )

    @prior.setter
    def prior(self, value):
        self._prior = value
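

# Example usage (illustrative sketch, not part of the deepof API; assumes a
# preprocessed array `coords` of shape (samples, timesteps, features)):
#
#   gmvae = SEQ_2_SEQ_GMVAE(loss="ELBO", number_of_components=10)
#   encoder, generator, grouper, gmvaep, kl_cb, mmd_cb = gmvae.build(coords.shape)
#   gmvaep.fit(coords, coords, batch_size=gmvae.batch_size,
#              callbacks=[cb for cb in (kl_cb, mmd_cb) if cb])
#   soft_clusters = grouper.predict(coords)  # per-sample cluster probabilities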

# noinspection PyDefaultArgument
# class SEQ_2_SEQ_CONV_GMVAE:
#     """  Gaussian Mixture Variational Autoencoder for pose motif elucidation.  """
#
#     def __init__(
#         self,
#         architecture_hparams: dict = {},
#         batch_size: int = 256,
#         compile_model: bool = True,
#         encoding: int = 16,
#         entropy_reg_weight: float = 0.0,
#         initialiser_iters: int = int(1),
#         kl_warmup_epochs: int = 20,
#         loss: str = "ELBO",
#         mmd_warmup_epochs: int = 20,
#         montecarlo_kl: int = 1,
#         neuron_control: bool = False,
#         number_of_components: int = 1,
#         overlap_loss: float = False,
#         phenotype_prediction: float = 0.0,
#         predictor: float = 0.0,
#         reg_cat_clusters: bool = False,
#         reg_cluster_variance: bool = False,
#     ):
#         self.hparams = self.get_hparams(architecture_hparams)
#         self.batch_size = batch_size
#         self.bidirectional_merge = self.hparams["bidirectional_merge"]
#         self.CONV_filters = self.hparams["units_conv"]
#         self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
#         self.DENSE_2 = self.hparams["units_dense2"]
#         self.DROPOUT_RATE = self.hparams["dropout_rate"]
#         self.ENCODING = encoding
#         self.LSTM_units_1 = self.hparams["units_lstm"]
#         self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
#         self.clipvalue = self.hparams["clipvalue"]
#         self.dense_activation = self.hparams["dense_activation"]
#         self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
#         self.learn_rate = self.hparams["learning_rate"]
#         self.lstm_unroll = True
#         self.compile = compile_model
#         self.entropy_reg_weight = entropy_reg_weight
#         self.initialiser_iters = initialiser_iters
#         self.kl_warmup = kl_warmup_epochs
#         self.loss = loss
#         self.mc_kl = montecarlo_kl
#         self.mmd_warmup = mmd_warmup_epochs
#         self.neuron_control = neuron_control
#         self.number_of_components = number_of_components
#         self.overlap_loss = overlap_loss
#         self.phenotype_prediction = phenotype_prediction
#         self.predictor = predictor
#         self.prior = "standard_normal"
#         self.reg_cat_clusters = reg_cat_clusters
#         self.reg_cluster_variance = reg_cluster_variance
#
#         assert (
#             "ELBO" in self.loss or "MMD" in self.loss
#         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
#
#     @property
#     def prior(self):
#         """Property to set the value of the prior
#     once the class is instantiated"""
#
#         return self._prior
#
#     def get_prior(self):
#         """Sets the Variational Autoencoder prior distribution"""
#
#         if self.prior == "standard_normal":
#             # init_means = deepof.model_utils.far_away_uniform_initialiser(
#             #     shape=(self.number_of_components, self.ENCODING),
#             #     minval=0,
#             #     maxval=5,
#             #     iters=self.initialiser_iters,
#             # )
#
#             self.prior = tfd.MixtureSameFamily(
#                 mixture_distribution=tfd.categorical.Categorical(
#                     probs=tf.ones(self.number_of_components) / self.number_of_components
#                 ),
#                 components_distribution=tfd.MultivariateNormalDiag(
#                     loc=tf.Variable(
#                         tf.random.normal(
#                             [self.number_of_components, self.ENCODING],
#                             name="prior_means",
#                         )
#                     ),
#                     scale_diag=tfp.util.TransformedVariable(
#                         tf.ones([self.number_of_components, self.ENCODING]),
#                         tfb.Softplus(),
#                         name="prior_scales",
#                     ),
#                 ),
#             )
#
#         else:  # pragma: no cover
#             raise NotImplementedError(
#                 "Gaussian Mixtures are currently the only supported prior"
#             )
#
#     @staticmethod
#     def get_hparams(params: Dict) -> Dict:
#         """Sets the default parameters for the model. Overwritable with a dictionary"""
#
#         defaults = {
#             "bidirectional_merge": "ave",
#             "clipvalue": 1.0,
#             "dense_activation": "relu",
#             "dense_layers_per_branch": 1,
#             "dropout_rate": 1e-3,
#             "learning_rate": 1e-3,
#             "units_conv": 160,
#             "units_dense2": 120,
#             "units_conv2": 300,
#         }
#
#         for k, v in params.items():
#             defaults[k] = v
#
#         return defaults
#
#     def get_layers(self, input_shape):
#         """Instanciate all layers in the model"""
#
#         # Encoder Layers
#
#         # Decoder layers
#
#         # Predictor layers
#
#         # Phenotype classification layers
#
#         pass
#
#     def build(self, input_shape: Tuple):
#         """Builds the tf.keras model"""
#
#         # Instantiate prior
#         self.get_prior()
#
#         # Get model layers
#         () = self.get_layers(input_shape)
#
#         # Define and instantiate encoder
#         x = Input(shape=input_shape[1:])
#         encoder = 0
#
#         # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
#         z_cat = Dense(
#             self.number_of_components,
#             activation="softmax",
#             kernel_regularizer=(
#                 tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
#                 if self.reg_cat_clusters
#                 else None
#             ),
#         )(encoder)
#
#         if self.entropy_reg_weight > 0:
#             z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(
#                 z_cat
#             )
#
#         z_gauss_mean = Dense(
#             tfpl.IndependentNormal.params_size(
#                 self.ENCODING * self.number_of_components
#             )
#             // 2,
#             activation=None,
#         )(encoder)
#
#         z_gauss_var = Dense(
#             tfpl.IndependentNormal.params_size(
#                 self.ENCODING * self.number_of_components
#             )
#             // 2,
#             activation=None,
#             activity_regularizer=(
#                 tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
#             ),
#         )(encoder)
#
#         z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
#
#         z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
#
#         # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
#         if self.neuron_control:
#             z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)
#
#         if self.overlap_loss:
#             z_gauss = deepof.model_utils.Gaussian_mixture_overlap(
#                 self.ENCODING,
#                 self.number_of_components,
#                 loss=self.overlap_loss,
#             )(z_gauss)
#
#         z = tfpl.DistributionLambda(
#             lambda gauss: tfd.mixture.Mixture(
#                 cat=tfd.categorical.Categorical(
#                     probs=gauss[0],
#                 ),
#                 components=[
#                     tfd.Independent(
#                         tfd.Normal(
#                             loc=gauss[1][..., : self.ENCODING, k],
#                             scale=softplus(gauss[1][..., self.ENCODING :, k]) + 1e-5,
#                         ),
#                         reinterpreted_batch_ndims=1,
#                     )
#                     for k in range(self.number_of_components)
#                 ],
#             ),
#             convert_to_tensor_fn="sample",
#         )([z_cat, z_gauss])
#