# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

deep autoencoder models for unsupervised pose detection

"""

from typing import Any, Dict, Tuple

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, random_uniform
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, GRU
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Nadam

import deepof.model_utils

tfb = tfp.bijectors
tfd = tfp.distributions
tfpl = tfp.layers

# noinspection PyDefaultArgument
class GMVAE:
    """Gaussian Mixture Variational Autoencoder for pose motif elucidation."""

    def __init__(
        self,
        architecture_hparams: dict = {},
        batch_size: int = 256,
        compile_model: bool = True,
        encoding: int = 6,
        kl_annealing_mode: str = "sigmoid",
        kl_warmup_epochs: int = 20,
        loss: str = "ELBO",
        mmd_annealing_mode: str = "sigmoid",
        mmd_warmup_epochs: int = 20,
        montecarlo_kl: int = 10,
        number_of_components: int = 1,
        overlap_loss: float = 0.0,
        next_sequence_prediction: float = 0.0,
        phenotype_prediction: float = 0.0,
        rule_based_prediction: float = 0.0,
        rule_based_features: int = 6,
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.batch_size = batch_size
        self.bidirectional_merge = self.hparams["bidirectional_merge"]
        self.CONV_filters = self.hparams["units_conv"]
        self.DENSE_1 = int(self.hparams["units_gru"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = encoding
        self.GRU_units_1 = self.hparams["units_gru"]
        self.GRU_units_2 = int(self.hparams["units_gru"] / 2)
        self.clipvalue = self.hparams["clipvalue"]
        self.dense_activation = self.hparams["dense_activation"]
        self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
        self.learn_rate = self.hparams["learning_rate"]
        self.gru_unroll = True
        self.compile = compile_model
        self.kl_annealing_mode = kl_annealing_mode
        self.kl_warmup = kl_warmup_epochs
        self.loss = loss
        self.mc_kl = montecarlo_kl
        self.mmd_annealing_mode = mmd_annealing_mode
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        self.optimizer = Nadam(learning_rate=self.learn_rate, clipvalue=self.clipvalue)
        self.overlap_loss = overlap_loss
        self.next_sequence_prediction = next_sequence_prediction
        self.phenotype_prediction = phenotype_prediction
        self.rule_based_prediction = rule_based_prediction
        self.rule_based_features = rule_based_features
        self.prior = "standard_normal"
        self.reg_cat_clusters = reg_cat_clusters
        self.reg_cluster_variance = reg_cluster_variance

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    @property
    def prior(self):
        """Property to set the value of the prior
        once the class is instantiated"""

        return self._prior

    def get_prior(self):
        """Sets the Variational Autoencoder prior distribution"""

        if self.prior == "standard_normal":

            self.prior = tfd.MixtureSameFamily(
                mixture_distribution=tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                components_distribution=tfd.MultivariateNormalDiag(
                    loc=tf.Variable(
                        he_uniform()(
                            [self.number_of_components, self.ENCODING],
                        ),
                        name="prior_means",
                    ),
                    scale_diag=tfp.util.TransformedVariable(
                        tf.ones([self.number_of_components, self.ENCODING])
                        / self.number_of_components,
                        tfb.Softplus(),
                        name="prior_scales",
                    ),
                ),
            )

        else:  # pragma: no cover
            raise NotImplementedError(
                "Gaussian Mixtures are currently the only supported prior"
            )
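
    # A minimal usage sketch (hypothetical, not part of the original module):
    # once `get_prior()` has run, `self.prior` holds a tfd.MixtureSameFamily
    # over the latent space, so e.g.
    #
    #     gmvae = GMVAE(encoding=6, number_of_components=10, compile_model=False)
    #     gmvae.get_prior()
    #     gmvae.prior.sample(5).shape  # -> TensorShape([5, 6])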

    @staticmethod
    def get_hparams(params: Dict) -> Dict:
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "bidirectional_merge": "concat",
            "clipvalue": 0.75,
            "dense_activation": "relu",
            "dense_layers_per_branch": 1,
            "dropout_rate": 0.1,
            "learning_rate": 1e-4,
            "units_conv": 64,
            "units_dense2": 32,
            "units_gru": 128,
        }

        for k, v in params.items():
            defaults[k] = v

        return defaults
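
    # A minimal sketch (hypothetical call): keys passed to the constructor via
    # `architecture_hparams` override the defaults above, e.g.
    #
    #     GMVAE.get_hparams({"units_gru": 256})["units_gru"]  # -> 256
    #     GMVAE.get_hparams({})["units_gru"]                  # -> 128 (default)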

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=2,  # Increased strides to yield shorter sequences
            padding="same",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_E1 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_E2 = Bidirectional(
            GRU(
                self.GRU_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            # kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        seq_E = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                # kernel_constraint=UnitNorm(axis=0),
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]
        Model_E4 = []
        for layer in seq_E:
            Model_E4.append(layer)
            Model_E4.append(BatchNormalization())

        # Decoder layers
        seq_D = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]
        Model_D1 = []
        for layer in seq_D:
            Model_D1.append(layer)
            Model_D1.append(BatchNormalization())

        Model_D2 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            GRU(
                self.GRU_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_D5 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_D6 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="same",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        # Predictor layers
        Model_P1 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_P2 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_P3 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_P4 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="same",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        # Phenotype classification layer
        Model_PC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

        # Rule-based trait classification layer
        Model_RC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_D6,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_P4,
            Model_PC1,
            Model_RC1,
        )

    def build(self, input_shape: Tuple):
        """Builds the tf.keras model"""

        # Instantiate prior
        self.get_prior()

        # Get model layers
        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_D6,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_P4,
            Model_PC1,
            Model_RC1,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        x = Input(shape=input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Sequential(Model_E4)(encoder)

        # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
        z_cat = Dense(
            self.number_of_components,
            name="cluster_assignment",
            activation="softmax",
            activity_regularizer=(
                tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
                if self.reg_cat_clusters
                else None
            ),
        )(encoder)

        z_gauss_mean = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            )
            // 2,
            name="cluster_means",
            activation=None,
            activity_regularizer=(tf.keras.regularizers.l1(10e-5)),
            kernel_initializer=he_uniform(),
        )(encoder)

        z_gauss_var = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            )
            // 2,
            name="cluster_variances",
            activation=None,
            activity_regularizer=(
                tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
            ),
            kernel_initializer=random_uniform(),
        )(encoder)

        z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)
        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)

        z = tfpl.DistributionLambda(
            make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(
                    probs=gauss[0],
                ),
                components=[
                    tfd.Independent(
                        tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=1e-3 + softplus(gauss[1][..., self.ENCODING :, k]),
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
            convert_to_tensor_fn="sample",
            name="encoding_distribution",
        )([z_cat, z_gauss])
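
        # Shapes at this point (per sample): z_cat -> (number_of_components,)
        # holds the mixture weights, and z_gauss -> (2 * ENCODING,
        # number_of_components) stacks each component's mean (first ENCODING
        # rows) on top of its pre-softplus scale (last ENCODING rows).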

        posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")

        # Define and control custom loss functions
        if "ELBO" in self.loss:
            kl_warm_up_iters = tf.cast(
                self.kl_warmup * (input_shape[0] // self.batch_size + 1),
                tf.int64,
            )

            # noinspection PyCallingNonCallable
            z = deepof.model_utils.KLDivergenceLayer(
                distribution_b=self.prior,
                test_points_fn=lambda q: q.sample(self.mc_kl),
                test_points_reduce_axis=0,
                iters=self.optimizer.iterations,
                warm_up_iters=kl_warm_up_iters,
                annealing_mode=self.kl_annealing_mode,
            )(z)

        if "MMD" in self.loss:
            mmd_warm_up_iters = tf.cast(
                self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
                tf.int64,
            )

            z = deepof.model_utils.MMDiscrepancyLayer(
                batch_size=self.batch_size,
                prior=self.prior,
                iters=self.optimizer.iterations,
                warm_up_iters=mmd_warm_up_iters,
                annealing_mode=self.mmd_annealing_mode,
            )(z)

        # Dummy layer with no parameters, to retrieve the previous tensor
        z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)

        if self.number_of_components > 1 and self.overlap_loss:
            z = deepof.model_utils.ClusterOverlap(
                batch_size=self.batch_size,
                encoding_dim=self.ENCODING,
                k=self.number_of_components,
                loss_weight=self.overlap_loss,
            )([z, z_cat])

        # Define and instantiate generator
        g = Input(shape=self.ENCODING)
        generator = Sequential(Model_D1)(g)
        generator = Model_D2(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D5(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D6(generator)
        generator = BatchNormalization()(generator)
        x_decoded_mean = Dense(
            tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
        )(generator)
        x_decoded_var = tf.keras.activations.softplus(
            Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
        )
        x_decoded_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(x_decoded_var)
        x_decoded = tf.keras.layers.concatenate(
            [x_decoded_mean, x_decoded_var], axis=-1
        )
        x_decoded_mean = tfpl.IndependentNormal(
            event_shape=input_shape[2:],
            convert_to_tensor_fn=tfp.distributions.Distribution.mean,
            name="vae_reconstruction",
        )(x_decoded)

        # define individual branches as models
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        generator = Model(g, x_decoded_mean, name="vae_reconstruction")

        def log_loss(x_true, p_x_q_given_z):
            """Computes the negative log likelihood of the data given
            the output distribution"""
            return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

        model_outs = [generator(encoder.outputs)]
        model_losses = [log_loss]
        model_metrics = {"vae_reconstruction": ["mae", "mse"]}
        loss_weights = [1.0]

        if self.next_sequence_prediction > 0:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P1(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(input_shape[1])(predictor)
            predictor = Model_P2(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P3(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P4(predictor)
            x_predicted_mean = Dense(
                tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
            )(predictor)
            x_predicted_var = tf.keras.activations.softplus(
                Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
                    predictor
                )
            )
            x_predicted_var = tf.keras.layers.Lambda(lambda x: 1e-3 + x)(
                x_predicted_var
            )
            x_decoded = tf.keras.layers.concatenate(
                [x_predicted_mean, x_predicted_var], axis=-1
            )
            x_predicted_mean = tfpl.IndependentNormal(
                event_shape=input_shape[2:],
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="vae_prediction",
            )(x_decoded)

            model_outs.append(x_predicted_mean)
            model_losses.append(log_loss)
            model_metrics["vae_prediction"] = ["mae", "mse"]
            loss_weights.append(self.next_sequence_prediction)

        if self.phenotype_prediction > 0:
            pheno_pred = Model_PC1(z)
            pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
            pheno_pred = tfpl.IndependentBernoulli(
                event_shape=1,
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="phenotype_prediction",
            )(pheno_pred)

            model_outs.append(pheno_pred)
            model_losses.append(log_loss)
            model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
            loss_weights.append(self.phenotype_prediction)

        if self.rule_based_prediction > 0:
            rule_pred = Model_RC1(z)

            rule_pred = Dense(
                tfpl.IndependentBernoulli.params_size(self.rule_based_features)
            )(rule_pred)
            rule_pred = tfpl.IndependentBernoulli(
                event_shape=self.rule_based_features,
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="rule_based_prediction",
            )(rule_pred)

            model_outs.append(rule_pred)
            model_losses.append(log_loss)
            model_metrics["rule_based_prediction"] = [
                "mae",
                "mse",
            ]
            loss_weights.append(self.rule_based_prediction)

        # Define grouper and end-to-end autoencoder model
        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
        gmvaep = Model(
            inputs=encoder.inputs,
            outputs=model_outs,
            name="SEQ_2_SEQ_GMVAE",
        )

        if self.compile:
            gmvaep.compile(
                loss=model_losses,
                optimizer=self.optimizer,
                metrics=model_metrics,
                loss_weights=loss_weights,
            )

        gmvaep.build(input_shape)

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            self.prior,
            posterior,
        )

    @prior.setter
    def prior(self, value):
        self._prior = value


# TODO:
#       - Check usefulness of stateful sequential layers! (stateful=True in the GRUs)
#       - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
#       - Explore expanding the event dims of the final reconstruction layer
#       - Think about gradient penalty to avoid mode collapse (as in WGAN-GP)
#       - Think about using spectral normalization
#       - REVISIT DROPOUT - CAN HELP WITH TRAINING STABILIZATION
#       - Decrease learning rate!
#       - Implement residual blocks!
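
# A minimal usage sketch (hypothetical shapes, not part of the original module):
# build() takes the full training-array shape (instances, window length,
# features) and returns the individual branches plus the end-to-end model.
#
#     gmvae = GMVAE(encoding=6, number_of_components=10, loss="ELBO")
#     encoder, generator, grouper, gmvaep, prior, posterior = gmvae.build(
#         input_shape=(10000, 15, 54)
#     )
#     gmvaep.summary()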