# @author lucasmiranda42
# encoding: utf-8
# module deepof

"""

deep autoencoder models for unsupervised pose detection

"""

from typing import Dict, Tuple

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.initializers import he_uniform, random_uniform
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, GRU
from tensorflow.keras.layers import RepeatVector, Reshape
from tensorflow.keras.optimizers import Nadam

import deepof.model_utils

tfb = tfp.bijectors
tfd = tfp.distributions
tfpl = tfp.layers


# noinspection PyDefaultArgument
class GMVAE:
    """  Gaussian Mixture Variational Autoencoder for pose motif elucidation.  """

    def __init__(
        self,
        architecture_hparams: dict = {},
        batch_size: int = 256,
        compile_model: bool = True,
        encoding: int = 6,
        kl_annealing_mode: str = "sigmoid",
        kl_warmup_epochs: int = 20,
        loss: str = "ELBO",
        mmd_annealing_mode: str = "sigmoid",
        mmd_warmup_epochs: int = 20,
        montecarlo_kl: int = 10,
        number_of_components: int = 1,
        overlap_loss: float = 0.0,
        next_sequence_prediction: float = 0.0,
        phenotype_prediction: float = 0.0,
        rule_based_prediction: float = 0.0,
        rule_based_features: int = 6,
        reg_cat_clusters: bool = False,
        reg_cluster_variance: bool = False,
    ):
        self.hparams = self.get_hparams(architecture_hparams)
        self.batch_size = batch_size
        self.bidirectional_merge = self.hparams["bidirectional_merge"]
        self.CONV_filters = self.hparams["units_conv"]
        self.DENSE_1 = int(self.hparams["units_gru"] / 2)
        self.DENSE_2 = self.hparams["units_dense2"]
        self.DROPOUT_RATE = self.hparams["dropout_rate"]
        self.ENCODING = encoding
        self.GRU_units_1 = self.hparams["units_gru"]
        self.GRU_units_2 = int(self.hparams["units_gru"] / 2)
        self.clipvalue = self.hparams["clipvalue"]
        self.dense_activation = self.hparams["dense_activation"]
        self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
        self.learn_rate = self.hparams["learning_rate"]
        self.gru_unroll = True
        self.compile = compile_model
        self.kl_annealing_mode = kl_annealing_mode
        self.kl_warmup = kl_warmup_epochs
        self.loss = loss
        self.mc_kl = montecarlo_kl
        self.mmd_annealing_mode = mmd_annealing_mode
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        self.optimizer = Nadam(learning_rate=self.learn_rate, clipvalue=self.clipvalue)
        self.overlap_loss = overlap_loss
        self.next_sequence_prediction = next_sequence_prediction
        self.phenotype_prediction = phenotype_prediction
        self.rule_based_prediction = rule_based_prediction
        self.rule_based_features = rule_based_features
        self.prior = "standard_normal"
        self.reg_cat_clusters = reg_cat_clusters
        self.reg_cluster_variance = reg_cluster_variance

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    @property
    def prior(self):
        """Property to set the value of the prior
        once the class is instantiated"""

        return self._prior

    def get_prior(self):
        """Sets the Variational Autoencoder prior distribution"""

        if self.prior == "standard_normal":
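            # Mixture of `number_of_components` diagonal Gaussians with
            # trainable means and softplus-constrained scales; the mixture
            # weights are fixed to be uniform.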

            self.prior = tfd.MixtureSameFamily(
                mixture_distribution=tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                components_distribution=tfd.MultivariateNormalDiag(
                    loc=tf.Variable(
                        he_uniform()(
                            [self.number_of_components, self.ENCODING],
                        ),
                        name="prior_means",
                    ),
                    scale_diag=tfp.util.TransformedVariable(
                        tf.ones([self.number_of_components, self.ENCODING])
                        / self.number_of_components,
                        tfb.Softplus(),
                        name="prior_scales",
                    ),
                ),
            )

        else:  # pragma: no cover
            raise NotImplementedError(
                "Gaussian Mixtures are currently the only supported prior"
            )

    @staticmethod
    def get_hparams(params: Dict) -> Dict:
        """Sets the default parameters for the model. Overwritable with a dictionary"""

        defaults = {
            "bidirectional_merge": "concat",
            "clipvalue": 0.75,
            "dense_activation": "relu",
            "dense_layers_per_branch": 1,
            "dropout_rate": 0.1,
            "learning_rate": 1e-4,
            "units_conv": 64,
            "units_dense2": 32,
            "units_gru": 128,
        }

        for k, v in params.items():
            defaults[k] = v

        return defaults

    def get_layers(self, input_shape):
        """Instanciate all layers in the model"""

        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="valid",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_E1 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_E2 = Bidirectional(
            GRU(
                self.GRU_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=0),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            # kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        seq_E = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                # kernel_constraint=UnitNorm(axis=0),
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]
        Model_E4 = []
        for layer in seq_E:
            Model_E4.append(layer)
            Model_E4.append(BatchNormalization())

        # Decoder layers
        seq_D = [
            Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
                use_bias=True,
            )
            for _ in range(self.dense_layers_per_branch)
        ]
        Model_D1 = []
        for layer in seq_D:
            Model_D1.append(layer)
            Model_D1.append(BatchNormalization())

        Model_D2 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_D3 = RepeatVector(input_shape[1])
        Model_D4 = Bidirectional(
            GRU(
                self.GRU_units_2,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_D5 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_D6 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="same",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        # Predictor layers
        Model_P1 = Dense(
            self.DENSE_1,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )
        Model_P2 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_P3 = Bidirectional(
            GRU(
                self.GRU_units_1,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=self.gru_unroll,
                # kernel_constraint=UnitNorm(axis=1),
                use_bias=True,
            ),
            merge_mode=self.bidirectional_merge,
        )
        Model_P4 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="same",
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
            use_bias=True,
        )

        # Phenotype classification layer
        Model_PC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

        # Rule based trait classification layer
        Model_RC1 = Dense(
            self.number_of_components,
            activation=self.dense_activation,
            kernel_initializer=he_uniform(),
        )

        return (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_D6,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_P4,
            Model_PC1,
            Model_RC1,
        )

    def build(self, input_shape: Tuple):
        """Builds the tf.keras model"""

        # Instantiate prior
        self.get_prior()

        # Get model layers
        (
            Model_E0,
            Model_E1,
            Model_E2,
            Model_E3,
            Model_E4,
            Model_D1,
            Model_D2,
            Model_D3,
            Model_D4,
            Model_D5,
            Model_D6,
            Model_P1,
            Model_P2,
            Model_P3,
            Model_P4,
            Model_PC1,
            Model_RC1,
        ) = self.get_layers(input_shape)

        # Define and instantiate encoder
        x = Input(shape=input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Sequential(Model_E4)(encoder)

        # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
        z_cat = Dense(
            self.number_of_components,
            name="cluster_assignment",
            activation="softmax",
            activity_regularizer=(
                tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
                if self.reg_cat_clusters
                else None
            ),
        )(encoder)
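        # `z_cat` parametrizes the categorical posterior over mixture components;
        # the optional activity regularizer penalizes over-confident softmax
        # assignments.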

        z_gauss_mean = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            )
            // 2,
            name="cluster_means",
            activation=None,
            activity_regularizer=(tf.keras.regularizers.l1(10e-5)),
            kernel_initializer=he_uniform(),
        )(encoder)

        z_gauss_var = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            )
            // 2,
            name="cluster_variances",
            activation=None,
            activity_regularizer=(
                tf.keras.regularizers.l2(0.01) if self.reg_cluster_variance else None
            ),
            kernel_initializer=random_uniform(),
        )(encoder)

        z_gauss = tf.keras.layers.concatenate([z_gauss_mean, z_gauss_var], axis=1)

        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)

        z = tfpl.DistributionLambda(
            make_distribution_fn=lambda gauss: tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(
                    probs=gauss[0],
                ),
                components=[
                    tfd.Independent(
                        tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=1e-3 + softplus(gauss[1][..., self.ENCODING :, k]),
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
            convert_to_tensor_fn="sample",
            name="encoding_distribution",
        )([z_cat, z_gauss])
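        # `z` samples from a mixture posterior: `z_cat` provides the component
        # weights, and component k is an independent Normal with means taken
        # from the first ENCODING rows of `z_gauss` and scales from a softplus
        # of the remaining rows (offset by 1e-3 for numerical stability).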

        posterior = Model(x, z, name="SEQ_2_SEQ_trained_distribution")

        # Define and control custom loss functions
        if "ELBO" in self.loss:
            kl_warm_up_iters = tf.cast(
                self.kl_warmup * (input_shape[0] // self.batch_size + 1),
                tf.int64,
            )
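            # Warm-up is configured in epochs but annealing runs per batch, so
            # epochs are converted to optimizer iterations via the number of
            # steps per epoch (input_shape[0] // batch_size + 1).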

            # noinspection PyCallingNonCallable
            z = deepof.model_utils.KLDivergenceLayer(
                distribution_b=self.prior,
                test_points_fn=lambda q: q.sample(self.mc_kl),
                test_points_reduce_axis=0,
                iters=self.optimizer.iterations,
                warm_up_iters=kl_warm_up_iters,
                annealing_mode=self.kl_annealing_mode,
            )(z)

        if "MMD" in self.loss:
            mmd_warm_up_iters = tf.cast(
                self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
                tf.int64,
            )

            z = deepof.model_utils.MMDiscrepancyLayer(
                batch_size=self.batch_size,
                prior=self.prior,
                iters=self.optimizer.iterations,
                warm_up_iters=mmd_warm_up_iters,
                annealing_mode=self.mmd_annealing_mode,
            )(z)

        # Dummy layer with no parameters, to retrieve the previous tensor
        z = tf.keras.layers.Lambda(lambda t: t, name="latent_distribution")(z)

        if self.number_of_components > 1 and self.overlap_loss:
            z = deepof.model_utils.ClusterOverlap(
                batch_size=self.batch_size,
                encoding_dim=self.ENCODING,
                k=self.number_of_components,
                loss_weight=self.overlap_loss,
            )([z, z_cat])

        # Define and instantiate generator
        g = Input(shape=(self.ENCODING,))
        generator = Sequential(Model_D1)(g)
        generator = Model_D2(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D5(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D6(generator)
        generator = BatchNormalization()(generator)
        x_decoded_mean = Dense(
            tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
        )(generator)
        x_decoded_var = tf.keras.activations.softplus(
            Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(generator)
        )
        x_decoded_var = tf.keras.layers.Lambda(lambda v: 1e-3 + v)(x_decoded_var)
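        # The 1e-3 offset keeps the decoder scale strictly positive and avoids
        # degenerate, near-zero-variance reconstruction distributions.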
        x_decoded = tf.keras.layers.concatenate(
            [x_decoded_mean, x_decoded_var], axis=-1
        )
        x_decoded_mean = tfpl.IndependentNormal(
            event_shape=input_shape[2:],
            convert_to_tensor_fn=tfp.distributions.Distribution.mean,
            name="vae_reconstruction",
        )(x_decoded)

        # define individual branches as models
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        generator = Model(g, x_decoded_mean, name="vae_reconstruction")

        def log_loss(x_true, p_x_q_given_z):
            """Computes the negative log likelihood of the data given
            the output distribution"""
            return -tf.reduce_sum(p_x_q_given_z.log_prob(x_true))

        model_outs = [generator(encoder.outputs)]
        model_losses = [log_loss]
        model_metrics = {"vae_reconstruction": ["mae", "mse"]}
        loss_weights = [1.0]
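        # Each optional head below adds one more model output, one more
        # log-likelihood loss, and a matching entry in `loss_weights`.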

        if self.next_sequence_prediction > 0:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2,
                activation=self.dense_activation,
                kernel_initializer=he_uniform(),
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P1(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(input_shape[1])(predictor)
            predictor = Model_P2(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P3(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Model_P4(predictor)
            x_predicted_mean = Dense(
                tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
            )(predictor)
            x_predicted_var = tf.keras.activations.softplus(
                Dense(tfpl.IndependentNormal.params_size(input_shape[2:]) // 2)(
                    predictor
                )
            )
            x_predicted_var = tf.keras.layers.Lambda(lambda v: 1e-3 + v)(
                x_predicted_var
            )
            x_decoded = tf.keras.layers.concatenate(
                [x_predicted_mean, x_predicted_var], axis=-1
            )
            x_predicted_mean = tfpl.IndependentNormal(
                event_shape=input_shape[2:],
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="vae_prediction",
            )(x_decoded)

            model_outs.append(x_predicted_mean)
            model_losses.append(log_loss)
            model_metrics["vae_prediction"] = ["mae", "mse"]
            loss_weights.append(self.next_sequence_prediction)

        if self.phenotype_prediction > 0:
            pheno_pred = Model_PC1(z)
            pheno_pred = Dense(tfpl.IndependentBernoulli.params_size(1))(pheno_pred)
            pheno_pred = tfpl.IndependentBernoulli(
                event_shape=1,
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="phenotype_prediction",
            )(pheno_pred)

            model_outs.append(pheno_pred)
            model_losses.append(log_loss)
            model_metrics["phenotype_prediction"] = ["AUC", "accuracy"]
            loss_weights.append(self.phenotype_prediction)

        if self.rule_based_prediction > 0:
            rule_pred = Model_RC1(z)

            rule_pred = Dense(
                tfpl.IndependentBernoulli.params_size(self.rule_based_features)
            )(rule_pred)
            rule_pred = tfpl.IndependentBernoulli(
                event_shape=self.rule_based_features,
                convert_to_tensor_fn=tfp.distributions.Distribution.mean,
                name="rule_based_prediction",
            )(rule_pred)

            model_outs.append(rule_pred)
            model_losses.append(log_loss)
            model_metrics["rule_based_prediction"] = [
                "mae",
                "mse",
            ]
            loss_weights.append(self.rule_based_prediction)

        # define grouper and end-to-end autoencoder model
        grouper = Model(encoder.inputs, z_cat, name="Deep_Gaussian_Mixture_clustering")
        gmvaep = Model(
            inputs=encoder.inputs,
            outputs=model_outs,
            name="SEQ_2_SEQ_GMVAE",
        )

        if self.compile:
            gmvaep.compile(
                loss=model_losses,
                optimizer=self.optimizer,
                metrics=model_metrics,
                loss_weights=loss_weights,
            )

        gmvaep.build(input_shape)

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            self.prior,
            posterior,
        )

    @prior.setter
    def prior(self, value):
        self._prior = value


# TODO:
#       - Check usefulness of stateful sequential layers! (stateful=True in the GRUs)
#       - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
#       - Explore expanding the event dims of the final reconstruction layer
#       - Think about gradient penalty to avoid mode collapse (as in WGAN-GP)
#       - Think about using spectral normalization
#       - REVISIT DROPOUT - CAN HELP WITH TRAINING STABILIZATION
#       - Decrease learning rate!
#       - Implement residual blocks!
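

# Minimal usage sketch (illustrative only, not part of deepof's API): trains the
# unsupervised branch on synthetic data. Shapes are hypothetical; real inputs come
# from deepof's preprocessing pipeline as (samples, timesteps, features) arrays.
if __name__ == "__main__":
    import numpy as np

    X = np.random.normal(size=(512, 15, 26)).astype("float32")

    encoder, generator, grouper, gmvaep, prior, posterior = GMVAE(
        encoding=6,
        number_of_components=5,
        batch_size=64,
    ).build(X.shape)

    # With all auxiliary prediction weights at their default of 0.0, the model
    # exposes a single reconstruction output, so inputs double as targets.
    gmvaep.fit(X, X, batch_size=64, epochs=1)

    # Soft cluster assignments per sequence: (512, 5)
    print(grouper.predict(X).shape)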