# @author lucasmiranda42

from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional, Dense
from tensorflow.keras.layers import Dropout, Lambda, LSTM
from tensorflow.keras.layers import RepeatVector, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Adam
from source.model_utils import *
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


class SEQ_2_SEQ_AE:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=self.input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(self.input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(reduction="sum", delta=100.0),
            optimizer=Adam(lr=self.learn_rate, clipvalue=0.5,),
            metrics=["mae"],
        )

        return encoder, decoder, model
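
# Usage sketch (illustrative only, not part of the original module): given a 3D array
# X of shape (samples, timesteps, features), the autoencoder can be built and trained
# on itself. The array name and the concrete shape below are hypothetical.
#
#   import numpy as np
#   X = np.random.normal(size=(1024, 100, 24)).astype("float32")
#   encoder, decoder, ae = SEQ_2_SEQ_AE(input_shape=X.shape).build()
#   ae.fit(X, X, batch_size=256, epochs=25, validation_split=0.1)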


class SEQ_2_SEQ_VAE:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs

        if self.prior == "standard_normal":
            self.prior = tfd.Independent(
                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
                reinterpreted_batch_ndims=1,
            )

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_B5 = BatchNormalization()
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)

        # z_mean = Dense(self.ENCODING)(encoder)
        # z_log_sigma = Dense(self.ENCODING)(encoder)

        # Project the encoding onto the parameter vector of the latent posterior
        encoder = Dense(
            tfpl.MultivariateNormalTriL.params_size(self.ENCODING), activation=None
        )(encoder)

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

            # z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

        # z = Lambda(sampling)([z_mean, z_log_sigma])
        # Sample the latent code; when ELBO is used, the KL term against the prior is
        # added through an activity regularizer, weighted by kl_beta
        z = tfpl.MultivariateNormalTriL(
            self.ENCODING,
            activity_regularizer=(
                tfpl.KLDivergenceRegularizer(self.prior, weight=kl_beta)
                if "ELBO" in self.loss
                else None
            ),
        )(encoder)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D0(z)
        generator = Model_B1(generator)
        generator = Model_D1(generator)
        generator = Model_B2(generator)
        generator = Model_D2(generator)
        generator = Model_B3(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B4(generator)
        generator = Model_D5(generator)
        generator = Model_B5(generator)
        x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        vae = Model(x, x_decoded_mean, name="SEQ_2_SEQ_VAE")

        # Build generator as a separate entity
        g = Input(shape=self.ENCODING)
        _generator = Model_D0(g)
        _generator = Model_B1(_generator)
        _generator = Model_D1(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B4(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B5(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        vae.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback
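
# Usage sketch (illustrative only): build() returns the two warm-up callbacks alongside
# the models; each one is either a LambdaCallback or False, so filter them before
# passing the list to fit(). The array X and the hyperparameters below are hypothetical.
#
#   encoder, generator, vae, kl_wu, mmd_wu = SEQ_2_SEQ_VAE(
#       input_shape=X.shape, kl_warmup_epochs=10, mmd_warmup_epochs=10
#   ).build()
#   vae.fit(
#       X, X, batch_size=256, epochs=50,
#       callbacks=[cb for cb in (kl_wu, mmd_wu) if cb],
#   )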


class SEQ_2_SEQ_VAEP:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_B5 = BatchNormalization()
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)

        z_mean = Dense(self.ENCODING)(encoder)
        z_log_sigma = Dense(self.ENCODING)(encoder)

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

        z = Lambda(sampling)([z_mean, z_log_sigma])

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D0(z)
        generator = Model_B1(generator)
        generator = Model_D1(generator)
        generator = Model_B2(generator)
        generator = Model_D2(generator)
        generator = Model_B3(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B4(generator)
        generator = Model_D5(generator)
        generator = Model_B5(generator)
        x_decoded_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_reconstruction"
        )(generator)

        # Define and instantiate predictor
        predictor = Dense(
            self.ENCODING, activation="relu", kernel_initializer=he_uniform()
        )(z)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = RepeatVector(self.input_shape[1])(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        x_predicted_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_prediction"
        )(predictor)

        # end-to-end autoencoder
        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
        vaep = Model(
            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
        )

        # Build generator as a separate entity
        g = Input(shape=self.ENCODING)
        _generator = Model_D0(g)
        _generator = Model_B1(_generator)
        _generator = Model_D1(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B4(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B5(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        vaep.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback
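
# Usage sketch (illustrative only): the VAEP is a two-output model (reconstruction and
# forward prediction, named "vaep_reconstruction" and "vaep_prediction"), so fit()
# expects two targets. The second target below is an assumption for illustration,
# e.g. a copy of the input shifted forward in time; X and X_shifted are hypothetical.
#
#   encoder, generator, vaep, kl_wu, mmd_wu = SEQ_2_SEQ_VAEP(input_shape=X.shape).build()
#   vaep.fit(
#       X, [X, X_shifted], batch_size=256, epochs=50,
#       callbacks=[cb for cb in (kl_wu, mmd_wu) if cb],
#   )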


class SEQ_2_SEQ_MMVAEP:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        number_of_components=1,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components

        assert (
            self.number_of_components > 0
        ), "The number of components must be an integer greater than zero"

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_B5 = BatchNormalization()
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)

        # Categorical prior on mixture of Gaussians
        categories = Dense(self.number_of_components, activation="softmax")

        # Define mean and log_sigma as lists of vectors with an item per prior component
        z_mean = []
        z_log_sigma = []
        for i in range(self.number_of_components):
            z_mean.append(
                Dense(self.ENCODING, name="{}_gaussian_mean".format(i + 1))(encoder)
            )
            z_log_sigma.append(
                Dense(self.ENCODING, name="{}_gaussian_sigma".format(i + 1))(encoder)
            )

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

            # NOTE: only the first mixture component is used downstream for now
            # (see the module-level TODO on Gaussian Mixture + Categorical priors)
            z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)(
                [z_mean[0], z_log_sigma[0]]
            )

        z = Lambda(sampling)([z_mean, z_log_sigma])

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D0(z)
        generator = Model_B1(generator)
        generator = Model_D1(generator)
        generator = Model_B2(generator)
        generator = Model_D2(generator)
        generator = Model_B3(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B4(generator)
        generator = Model_D5(generator)
        generator = Model_B5(generator)
        x_decoded_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="gmvaep_reconstruction"
        )(generator)

        # Define and instantiate predictor
        predictor = Dense(
            self.ENCODING, activation="relu", kernel_initializer=he_uniform()
        )(z)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = RepeatVector(self.input_shape[1])(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        x_predicted_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="gmvaep_prediction"
        )(predictor)

        # end-to-end autoencoder
        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
        gmvaep = Model(
            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
        )

        # Build generator as a separate entity
        g = Input(shape=self.ENCODING)
        _generator = Model_D0(g)
        _generator = Model_B1(_generator)
        _generator = Model_D1(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B4(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B5(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        gmvaep.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback
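
# Usage sketch (illustrative only): the stand-alone generator maps latent vectors of
# size ENCODING (32 by default) back to full sequences, so new samples can be drawn by
# decoding random codes after training. X, X_shifted and the shapes are hypothetical,
# as in the SEQ_2_SEQ_VAEP example above.
#
#   import numpy as np
#   encoder, generator, gmvaep, kl_wu, mmd_wu = SEQ_2_SEQ_MMVAEP(
#       input_shape=X.shape, number_of_components=5
#   ).build()
#   gmvaep.fit(
#       X, [X, X_shifted], batch_size=256, epochs=50,
#       callbacks=[cb for cb in (kl_wu, mmd_wu) if cb],
#   )
#   new_sequences = generator.predict(np.random.normal(size=(16, 32)))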


# TODO:
#       - Gaussian Mixture + Categorical priors -> Deep Clustering
#       - MCMC sampling (n>1)
#
# TODO (in the non-immediate future):
#       - free bits paper
#       - Attention mechanism for encoder / decoder (does it make sense?)
#       - Transformer encoder/decoder (does it make sense?)