# @author lucasmiranda42

from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Adam
from source.model_utils import *
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


class SEQ_2_SEQ_AE:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
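        # DenseTranspose comes from source.model_utils; given the encoder layers it
        # wraps, it presumably reuses their kernels in transposed (tied-weights) form
        # to mirror the encoder's Dense stack in the decoder.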
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=self.input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(self.input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(reduction="sum", delta=100.0),
            optimizer=Adam(lr=self.learn_rate, clipvalue=0.5,),
            metrics=["mae"],
        )

        return encoder, decoder, model
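
# Usage sketch (not part of the original module): assuming a hypothetical array `X`
# of shape (samples, timesteps, features),
#
#     encoder, decoder, ae = SEQ_2_SEQ_AE(input_shape=X.shape).build()
#     ae.fit(X, X, epochs=25, batch_size=256)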


class SEQ_2_SEQ_GMVAE:
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
        number_of_components=1,
        predictor=True,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        self.predictor = predictor

        if self.prior == "standard_normal":
            self.prior = tfd.mixture.Mixture(
                tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                [
                    tfd.Independent(
                        tfd.Normal(loc=tf.zeros(self.ENCODING), scale=1),
                        reinterpreted_batch_ndims=1,
                    )
                    for _ in range(self.number_of_components)
                ],
            )
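        # The resulting default prior is a uniform mixture of `number_of_components`
        # standard normals over the ENCODING-dimensional latent space; with a single
        # component it reduces to a plain standard normal prior.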

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_D1 = Dense(
            self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
        )
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)

        z_cat = Dense(self.number_of_components, activation="softmax")(encoder)
        z_gauss = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            ),
            activation=None,
        )(encoder)
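        # z_cat outputs the mixture responsibilities (softmax over components), while
        # z_gauss outputs 2 * ENCODING unconstrained parameters per component
        # (locations and pre-softplus scales), which are split further below.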

        # Define and control custom loss functions
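        # If warm-up is requested, the corresponding beta weight is annealed linearly
        # from 0 to 1 over kl_warmup / mmd_warmup epochs by a LambdaCallback; the
        # returned callbacks must be passed to fit() for the schedule to take effect.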
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

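        # Posterior: a Gaussian mixture whose k-th component is an independent Normal
        # with loc taken from the first ENCODING entries of z_gauss and scale from a
        # softplus of the remaining ones; z_cat supplies the mixture weights.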
        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
        z = tfpl.DistributionLambda(
            lambda gauss: tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(probs=gauss[0],),
                components=[
                    tfd.Independent(
                        tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
        )([z_cat, z_gauss])

        if "ELBO" in self.loss:
            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)

        # Define and instantiate generator
        generator = Model_D1(z)
        generator = Model_B1(generator)
        generator = Model_D2(generator)
        generator = Model_B2(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B3(generator)
        generator = Model_D5(generator)
        generator = Model_B4(generator)
        x_decoded_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_reconstruction"
        )(generator)

        if self.predictor:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Dense(
                self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
            )(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(self.input_shape[1])(predictor)
            predictor = Bidirectional(
                LSTM(
                    self.LSTM_units_1,
                    activation="tanh",
                    return_sequences=True,
                    kernel_constraint=UnitNorm(axis=1),
                )
            )(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Bidirectional(
                LSTM(
                    self.LSTM_units_1,
                    activation="sigmoid",
                    return_sequences=True,
                    kernel_constraint=UnitNorm(axis=1),
                )
            )(predictor)
            predictor = BatchNormalization()(predictor)
            x_predicted_mean = TimeDistributed(
                Dense(self.input_shape[2]), name="vaep_prediction"
            )(predictor)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        grouper = Model(x, z_cat, name="Deep_Gaussian_Mixture_clustering")
        gmvaep = Model(
            inputs=x,
            outputs=(
                [x_decoded_mean, x_predicted_mean] if self.predictor else x_decoded_mean
            ),
            name="SEQ_2_SEQ_VAE",
        )

        # Build generator as a separate entity
        g = Input(shape=(self.ENCODING,))
        _generator = Model_D1(g)
        _generator = Model_B1(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B4(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        gmvaep.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            kl_warmup_callback,
            mmd_warmup_callback,
        )
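
# Usage sketch (not part of the original module): assuming a hypothetical array `X`
# of shape (samples, timesteps, features),
#
#     encoder, generator, grouper, gmvaep, kl_cb, mmd_cb = SEQ_2_SEQ_GMVAE(
#         input_shape=X.shape,
#         predictor=False,
#         kl_warmup_epochs=10,
#         mmd_warmup_epochs=10,
#     ).build()
#     callbacks = [cb for cb in (kl_cb, mmd_cb) if cb]
#     gmvaep.fit(X, X, epochs=25, batch_size=256, callbacks=callbacks)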


# TODO:
#       - Try Bayesian nets!
#       - MCMC sampling (n>1) (already supported by tfp! we should try it)
#
# TODO (in the non-immediate future):
#       - free bits paper
#       - Attention mechanism for encoder / decoder (does it make sense?)
#       - Transformer encoder/decoder (does it make sense?)