# @author lucasmiranda42

from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal, RandomNormal
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Adam
from source.model_utils import *
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers


class SEQ_2_SEQ_AE:
    """Simple sequence-to-sequence autoencoder for fixed-length time-series windows.

    The encoder stacks a causal Conv1D, two bidirectional LSTMs and three dense
    layers down to a bottleneck of size `encoding`; the decoder mirrors it with
    tied (transposed) dense layers, a RepeatVector and bidirectional LSTMs.
    """

    def __init__(
        self,
        input_shape,
        units_conv=256,
        units_lstm=256,
        units_dense2=64,
        dropout_rate=0.25,
        encoding=16,
        learning_rate=1e-3,
    ):
        self.input_shape = input_shape
        self.CONV_filters = units_conv
        self.LSTM_units_1 = units_lstm
        self.LSTM_units_2 = int(units_lstm / 2)
        self.DENSE_1 = int(units_lstm / 2)
        self.DENSE_2 = units_dense2
        self.DROPOUT_RATE = dropout_rate
        self.ENCODING = encoding
        self.learn_rate = learning_rate

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
            kernel_initializer=Orthogonal(),
        )

        # Decoder layers
        Model_D0 = DenseTranspose(
            Model_E5, activation="relu", output_dim=self.ENCODING,
        )
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=self.input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(self.input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(reduction="sum", delta=100.0),
            optimizer=Adam(lr=self.learn_rate, clipvalue=0.5,),
            metrics=["mae"],
        )

        return encoder, decoder, model
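
    # Example usage (sketch; shapes and variable names are illustrative only,
    # and assume `import numpy as np`):
    #
    #     X = np.random.normal(size=(1000, 24, 28))  # (windows, timesteps, features)
    #     encoder, decoder, model = SEQ_2_SEQ_AE(input_shape=X.shape).build()
    #     model.fit(X, X, batch_size=256, epochs=25)
    #     latent = encoder.predict(X)  # shape (1000, 16) with the default encoding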


class SEQ_2_SEQ_GMVAE:
    """Gaussian-mixture variational autoencoder (GMVAE) for time-series windows.

    Uses the same convolutional/recurrent encoder stack as SEQ_2_SEQ_AE, but maps
    each window to a mixture-of-Gaussians posterior over the latent space,
    regularised with ELBO (KL) and/or MMD terms, and can optionally attach a
    second decoder head ("vaep_prediction") next to the reconstruction head.
    """

    def __init__(
        self,
        input_shape,
        units_conv=256,
        units_lstm=256,
        units_dense2=64,
        dropout_rate=0.25,
        encoding=16,
        learning_rate=1e-3,
        loss="ELBO+MMD",
        kl_warmup_epochs=0,
        mmd_warmup_epochs=0,
        prior="standard_normal",
        number_of_components=1,
        predictor=True,
    ):
        self.input_shape = input_shape
        self.CONV_filters = units_conv
        self.LSTM_units_1 = units_lstm
        self.LSTM_units_2 = int(units_lstm / 2)
        self.DENSE_1 = int(units_lstm / 2)
        self.DENSE_2 = units_dense2
        self.DROPOUT_RATE = dropout_rate
        self.ENCODING = encoding
        self.learn_rate = learning_rate
        self.loss = loss
        self.prior = prior
        self.kl_warmup = kl_warmup_epochs
        self.mmd_warmup = mmd_warmup_epochs
        self.number_of_components = number_of_components
        self.predictor = predictor

        if self.prior == "standard_normal":
            # Prior: an evenly weighted mixture of unit-scale Gaussians with
            # randomly initialised means in the latent space
            self.prior = tfd.mixture.Mixture(
                tfd.categorical.Categorical(
                    probs=tf.ones(self.number_of_components) / self.number_of_components
                ),
                [
                    tfd.Independent(
                        tfd.Normal(
                            loc=tf.random.normal(
                                shape=[self.ENCODING], stddev=1 / self.ENCODING
                            ),
                            scale=1,
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for _ in range(self.number_of_components)
                ],
            )

        assert (
            "ELBO" in self.loss or "MMD" in self.loss
        ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

    def build(self):
        # Encoder Layers
        Model_E0 = tf.keras.layers.Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
            kernel_initializer=he_uniform(),
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )
        Model_E4 = Dense(
            self.DENSE_2,
            activation="relu",
            kernel_constraint=UnitNorm(axis=0),
            kernel_initializer=he_uniform(),
        )

        # Decoder layers
        Model_B1 = BatchNormalization()
        Model_B2 = BatchNormalization()
        Model_B3 = BatchNormalization()
        Model_B4 = BatchNormalization()
        Model_D1 = Dense(
            self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
        )
        Model_D2 = Dense(
            self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
        )
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)

        # Soft cluster assignments (mixture weights) for each input window
        z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
        # Location and (pre-softplus) scale parameters for every mixture component
        z_gauss = Dense(
            tfpl.IndependentNormal.params_size(
                self.ENCODING * self.number_of_components
            ),
            activation=None,
        )(encoder)

        z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
        z = tfpl.DistributionLambda(
            lambda gauss: tfd.mixture.Mixture(
                cat=tfd.categorical.Categorical(probs=gauss[0],),
                components=[
                    tfd.Independent(
                        tfd.Normal(
                            loc=gauss[1][..., : self.ENCODING, k],
                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                        ),
                        reinterpreted_batch_ndims=1,
                    )
                    for k in range(self.number_of_components)
                ],
            ),
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
        )([z_cat, z_gauss])

        # Define and control custom loss functions
        kl_warmup_callback = False
        if "ELBO" in self.loss:

            kl_beta = K.variable(1.0, name="kl_beta")
            kl_beta._trainable = False
            if self.kl_warmup:
                kl_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        kl_beta, K.min([epoch / self.kl_warmup, 1])
                    )
                )

            z = KLDivergenceLayer(self.prior, weight=kl_beta)(z)

        mmd_warmup_callback = False
        if "MMD" in self.loss:

            mmd_beta = K.variable(1.0, name="mmd_beta")
            mmd_beta._trainable = False
            if self.mmd_warmup:
                mmd_warmup_callback = LambdaCallback(
                    on_epoch_begin=lambda epoch, logs: K.set_value(
                        mmd_beta, K.min([epoch / self.mmd_warmup, 1])
                    )
                )

            z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)

        # z = Latent_space_control()(z, z_gauss, z_cat)

        # Latent space callback to control dead (zero) dimensions in the latent space
        dead_neuron_rate_callback = LambdaCallback(
            on_epoch_end=lambda epoch, logs: tf.math.zero_fraction(z_gauss)
        )

        # Latent space callback to monitor the silhouette index of the latent clustering
        # (assumes a silhouette_score implementation is available, e.g. via the
        # model_utils star import)
        silhouette_callback = LambdaCallback(
            on_epoch_end=lambda epoch, logs: tf.numpy_function(
                silhouette_score, [z, tf.math.argmax(z_cat, axis=1)], tf.float32
            )
        )

        # Define and instantiate generator
        generator = Model_D1(z)
        generator = Model_B1(generator)
        generator = Model_D2(generator)
        generator = Model_B2(generator)
        generator = Model_D3(generator)
        generator = Model_D4(generator)
        generator = Model_B3(generator)
        generator = Model_D5(generator)
        generator = Model_B4(generator)
        x_decoded_mean = TimeDistributed(
            Dense(self.input_shape[2]), name="vaep_reconstruction"
        )(generator)

        if self.predictor:
            # Define and instantiate predictor
            predictor = Dense(
                self.DENSE_2, activation="relu", kernel_initializer=he_uniform()
            )(z)
            predictor = BatchNormalization()(predictor)
            predictor = Dense(
                self.DENSE_1, activation="relu", kernel_initializer=he_uniform()
            )(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = RepeatVector(self.input_shape[1])(predictor)
            predictor = Bidirectional(
                LSTM(
                    self.LSTM_units_1,
                    activation="tanh",
                    return_sequences=True,
                    kernel_constraint=UnitNorm(axis=1),
                )
            )(predictor)
            predictor = BatchNormalization()(predictor)
            predictor = Bidirectional(
                LSTM(
                    self.LSTM_units_1,
                    activation="sigmoid",
                    return_sequences=True,
                    kernel_constraint=UnitNorm(axis=1),
                )
            )(predictor)
            predictor = BatchNormalization()(predictor)
            x_predicted_mean = TimeDistributed(
                Dense(self.input_shape[2]), name="vaep_prediction"
            )(predictor)

        # end-to-end autoencoder
        encoder = Model(x, z, name="SEQ_2_SEQ_VEncoder")
        grouper = Model(x, z_cat, name="Deep_Gaussian_Mixture_clustering")
        gmvaep = Model(
            inputs=x,
            outputs=(
                [x_decoded_mean, x_predicted_mean] if self.predictor else x_decoded_mean
            ),
            name="SEQ_2_SEQ_VAE",
        )

        # Build generator as a separate entity
        g = Input(shape=(self.ENCODING,))
        _generator = Model_D1(g)
        _generator = Model_B1(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_B2(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_B3(_generator)
        _generator = Model_D5(_generator)
        _generator = Model_B4(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x_, x_decoded_mean_):
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1:] * huber(x_, x_decoded_mean_)

        gmvaep.compile(
            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )

        return (
            encoder,
            generator,
            grouper,
            gmvaep,
            dead_neuron_rate_callback,
            silhouette_callback,
            kl_warmup_callback,
            mmd_warmup_callback,
        )
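
    # Example usage (sketch; shapes, names and hyperparameters are illustrative):
    #
    #     gmvae = SEQ_2_SEQ_GMVAE(
    #         input_shape=X.shape, number_of_components=5,
    #         kl_warmup_epochs=10, predictor=False,
    #     )
    #     encoder, generator, grouper, gmvaep, _, _, kl_wu, mmd_wu = gmvae.build()
    #     callbacks = [cb for cb in (kl_wu, mmd_wu) if cb]
    #     gmvaep.fit(X, X, batch_size=256, epochs=50, callbacks=callbacks)
    #     soft_clusters = grouper.predict(X)  # per-window mixture responsibilities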


# TODO:
#       - latent space metrics to control overregularization (turned-off dimensions). Useful for warmup tuning
#       - Clustering metrics for model selection and to aid training (e.g. early stopping)
#           - Silhouette / likelihood (AIC / BIC) / classifier accuracy metrics
#       - design a clustering-conscious hyperparameter tuning pipeline

# TODO (in the non-immediate future):
#       - Try Bayesian nets!
#       - MCMC sampling (n>1) (already supported by tfp! we should try it)
#       - free bits paper
#       - Attention mechanism for encoder / decoder (does it make sense?)
#       - Transformer encoder/decoder (does it make sense?)
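

if __name__ == "__main__":
    # Minimal smoke test (sketch): build both models on synthetic data and print
    # their summaries. Shapes are illustrative; real inputs come from the
    # preprocessing pipeline, which is outside this module.
    import numpy as np

    X = np.random.normal(size=(100, 24, 28)).astype("float32")

    _, _, ae = SEQ_2_SEQ_AE(input_shape=X.shape).build()
    ae.summary()

    gmvae_models = SEQ_2_SEQ_GMVAE(
        input_shape=X.shape, number_of_components=5, predictor=False
    ).build()
    gmvae_models[3].summary()  # the end-to-end SEQ_2_SEQ_VAE model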