# @author lucasmiranda42

from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional, Dense
from tensorflow.keras.layers import Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Adam
from source.model_utils import *
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


class SEQ_2_SEQ_AE:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=self.input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(self.input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(reduction="sum", delta=100.0),
            optimizer=Adam(lr=self.learn_rate, clipvalue=0.5,),
            metrics=["mae"],
        )

        return encoder, decoder, model
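

# Minimal usage sketch for SEQ_2_SEQ_AE (illustrative only). `X_train` is a
# hypothetical preprocessed array of shape (samples, timesteps, features);
# batch size and epochs are placeholder values, not tuned settings.
#
#   encoder, decoder, ae = SEQ_2_SEQ_AE(input_shape=X_train.shape).build()
#   ae.fit(X_train, X_train, batch_size=256, epochs=25, validation_split=0.1)
#   latent = encoder.predict(X_train)  # (samples, ENCODING) embeddings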


class SEQ_2_SEQ_VAE:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs

        if self.prior == "standard_normal":
            self.prior = tfd.Independent(
                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
                reinterpreted_batch_ndims=1,
            )

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_B5 = BatchNormalization()
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)

        encoder = Dense(
            tfpl.IndependentNormal.params_size(self.ENCODING), activation=None
        )(encoder)

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

        z = tfpl.IndependentNormal(self.ENCODING)(encoder)

        if "ELBO" in self.loss:
            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D0(z)
        generator = Model_B1(generator)
        generator = Model_D1(generator)
        generator = Model_B2(generator)
        generator = Model_D2(generator)
        generator = Model_B3(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B4(generator)
        generator = Model_D5(generator)
        generator = Model_B5(generator)
        x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        vae = Model(x, x_decoded_mean, name="SEQ_2_SEQ_VAE")

        # Build generator as a separate entity
        g = Input(shape=self.ENCODING)
        _generator = Model_D0(g)
        _generator = Model_B1(_generator)
        _generator = Model_D1(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B4(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B5(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        vae.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback
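

# Minimal usage sketch for SEQ_2_SEQ_VAE (illustrative only). `X_train` is a
# hypothetical (samples, timesteps, features) array. The warm-up callbacks are
# real callback objects only when kl_warmup_epochs / mmd_warmup_epochs > 0,
# so falsy placeholders are filtered out before fitting.
#
#   encoder, generator, vae, kl_wu, mmd_wu = SEQ_2_SEQ_VAE(
#       input_shape=X_train.shape, loss="ELBO+MMD",
#       kl_warmup_epochs=10, mmd_warmup_epochs=10,
#   ).build()
#   callbacks = [cb for cb in (kl_wu, mmd_wu) if cb]
#   vae.fit(X_train, X_train, batch_size=256, epochs=50, callbacks=callbacks)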


class SEQ_2_SEQ_VAEP:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs

        if self.prior == "standard_normal":
            self.prior = tfd.Independent(
                tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
                reinterpreted_batch_ndims=1,
            )

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_B5 = BatchNormalization()
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)

        encoder = Dense(
            tfpl.IndependentNormal.params_size(self.ENCODING), activation=None
        )(encoder)

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

        z = tfpl.IndependentNormal(self.ENCODING)(encoder)

        if "ELBO" in self.loss:
            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D0(z)
        generator = Model_B1(generator)
        generator = Model_D1(generator)
        generator = Model_B2(generator)
        generator = Model_D2(generator)
        generator = Model_B3(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B4(generator)
        generator = Model_D5(generator)
        generator = Model_B5(generator)
        x_decoded_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_reconstruction"
        )(generator)

        # Define and instantiate predictor
        predictor = Dense(
            self.ENCODING, activation="relu", kernel_initializer=he_uniform()
        )(z)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = RepeatVector(self.input_shape[1])(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        x_predicted_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_prediction"
        )(predictor)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        vaep = Model(
            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAEP"
        )

        # Build generator as a separate entity
        g = Input(shape=self.ENCODING)
        _generator = Model_D0(g)
        _generator = Model_B1(_generator)
        _generator = Model_D1(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B4(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B5(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        vaep.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback
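

# Minimal usage sketch for SEQ_2_SEQ_VAEP (illustrative only). The model has two
# outputs (reconstruction and prediction), so it expects two targets: the input
# itself and a hypothetical `X_next` array holding the sequences shifted one
# window ahead, both of shape (samples, timesteps, features).
#
#   encoder, generator, vaep, kl_wu, mmd_wu = SEQ_2_SEQ_VAEP(
#       input_shape=X_train.shape
#   ).build()
#   callbacks = [cb for cb in (kl_wu, mmd_wu) if cb]
#   vaep.fit(X_train, [X_train, X_next], batch_size=256, epochs=50, callbacks=callbacks)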


class SEQ_2_SEQ_MMVAEP:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
        number_of_components=1,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components

        if self.prior == "standard_normal":
            self.prior = tfd.mixture.Mixture(
                tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                [
                    tfd.Independent(
                        tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
                        reinterpreted_batch_ndims=1,
                    )
                    for _ in range(self.number_of_components)
                ],
            )

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_B5 = BatchNormalization()
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)

        # Map encoder to a categorical distribution over the components
        zcat = Dense(self.number_of_components, activation="softmax")(encoder)

        # Map encoder to a dense layer representing the parameters of
        # the gaussian mixture latent space
        zgauss = Dense(
            tfpl.MixtureNormal.params_size(self.number_of_components, self.ENCODING),
            activation=None,
        )(encoder)

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

        z = tfpl.MixtureNormal(self.number_of_components, self.ENCODING)(zgauss)

        if "ELBO" in self.loss:
            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D0(z)
        generator = Model_B1(generator)
        generator = Model_D1(generator)
        generator = Model_B2(generator)
        generator = Model_D2(generator)
        generator = Model_B3(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B4(generator)
        generator = Model_D5(generator)
        generator = Model_B5(generator)
        x_decoded_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_reconstruction"
        )(generator)

        # Define and instantiate predictor
        predictor = Dense(
            self.ENCODING, activation="relu", kernel_initializer=he_uniform()
        )(z)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Dense(
            self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = RepeatVector(self.input_shape[1])(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        predictor = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )(predictor)
        predictor = BatchNormalization()(predictor)
        x_predicted_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_prediction"
        )(predictor)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        gmvaep = Model(
            inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
        )

        # Build generator as a separate entity
        g = Input(shape=self.ENCODING)
        _generator = Model_D0(g)
        _generator = Model_B1(_generator)
        _generator = Model_D1(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B4(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B5(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        gmvaep.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback
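

# Minimal usage sketch for SEQ_2_SEQ_MMVAEP (illustrative only). Compared to
# SEQ_2_SEQ_VAEP, the latent space is a mixture of `number_of_components`
# Gaussians (tfpl.MixtureNormal); `X_train` / `X_next` are hypothetical
# (samples, timesteps, features) arrays as in the sketches above.
#
#   encoder, generator, gmvaep, kl_wu, mmd_wu = SEQ_2_SEQ_MMVAEP(
#       input_shape=X_train.shape, number_of_components=5
#   ).build()
#   callbacks = [cb for cb in (kl_wu, mmd_wu) if cb]
#   gmvaep.fit(X_train, [X_train, X_next], batch_size=256, epochs=50, callbacks=callbacks)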


# TODO:
#       - Gaussian Mixture + Categorical priors -> Deep Clustering
#           - prior of equal gaussians
#           - prior of equal gaussians + gaussian noise on the means (not exactly the same init)
#       - MCMC sampling (n>1) (already supported by tfp! we should try it)
#
# TODO (in the non-immediate future):
#       - free bits paper
#       - Attention mechanism for encoder / decoder (does it make sense?)
#       - Transformer encoder/decoder (does it make sense?)