# @author lucasmiranda42

from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.layers import BatchNormalization, Bidirectional, Conv1D, Dense
from tensorflow.keras.layers import Dropout, Lambda, LSTM
from tensorflow.keras.layers import RepeatVector, TimeDistributed
from tensorflow.keras.losses import Huber
from tensorflow.keras.optimizers import Adam

from source.model_utils import (
    DenseTranspose,
    KLDivergenceLayer,
    MMDiscrepancyLayer,
    UncorrelatedFeaturesConstraint,
    sampling,
)


class SEQ_2_SEQ_AE:
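    """Sequence-to-sequence autoencoder. A causal Conv1D front end feeds a
    stack of bidirectional LSTMs and Dense layers that compress the input
    sequences into an ENCODING-dimensional bottleneck; the decoder mirrors
    the encoder, with Dense weights tied through DenseTranspose."""
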
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate

    def build(self):
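        """Builds the model and returns an (encoder, decoder, autoencoder)
        tuple, with the end-to-end autoencoder already compiled."""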
        # Encoder layers
        Model_E0 = Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0)
        )
        Model_E4 = Dense(
            self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0)
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
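            # Encourage uncorrelated features in the latent code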
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
        )

        # Decoder layers
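        # Each DenseTranspose layer reuses the (transposed) kernel of the
        # encoder layer it wraps, tying encoder and decoder weights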
        Model_D0 = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        encoder = Sequential(name="SEQ_2_SEQ_Encoder")
        encoder.add(Input(shape=self.input_shape[1:]))
        encoder.add(Model_E0)
        encoder.add(BatchNormalization())
        encoder.add(Model_E1)
        encoder.add(BatchNormalization())
        encoder.add(Model_E2)
        encoder.add(BatchNormalization())
        encoder.add(Model_E3)
        encoder.add(BatchNormalization())
        encoder.add(Dropout(self.DROPOUT_RATE))
        encoder.add(Model_E4)
        encoder.add(BatchNormalization())
        encoder.add(Model_E5)

        # Define and instantiate decoder
        decoder = Sequential(name="SEQ_2_SEQ_Decoder")
        decoder.add(Model_D0)
        decoder.add(BatchNormalization())
        decoder.add(Model_D1)
        decoder.add(BatchNormalization())
        decoder.add(Model_D2)
        decoder.add(BatchNormalization())
        decoder.add(Model_D3)
        decoder.add(BatchNormalization())
        decoder.add(Model_D4)
        decoder.add(BatchNormalization())
        decoder.add(Model_D5)
        decoder.add(TimeDistributed(Dense(self.input_shape[2])))

        model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")

        model.compile(
            loss=Huber(reduction="sum", delta=100.0),
            optimizer=Adam(learning_rate=self.learn_rate, clipvalue=0.5),
            metrics=["mae"],
        )

        return encoder, decoder, model


class SEQ_2_SEQ_VAE:
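    """Variational version of SEQ_2_SEQ_AE. Shares the same layer stack, but
    the bottleneck parameterizes a Gaussian posterior (z_mean, z_log_sigma),
    regularized with a KL divergence ("ELBO") and/or a maximum mean
    discrepancy ("MMD") term depending on the loss argument."""
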
    def __init__(
        self,
        input_shape,
        CONV_filters=256,
        LSTM_units_1=256,
        LSTM_units_2=64,
        DENSE_2=64,
        DROPOUT_RATE=0.25,
        ENCODING=32,
        learn_rate=1e-3,
        loss="ELBO+MMD",
    ):
        self.input_shape = input_shape
        self.CONV_filters = CONV_filters
        self.LSTM_units_1 = LSTM_units_1
        self.LSTM_units_2 = LSTM_units_2
        self.DENSE_1 = LSTM_units_2
        self.DENSE_2 = DENSE_2
        self.DROPOUT_RATE = DROPOUT_RATE
        self.ENCODING = ENCODING
        self.learn_rate = learn_rate
        self.loss = loss

    def build(self):
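        """Builds the model and returns an (encoder, generator, vae) tuple;
        the encoder maps inputs to z_mean, the generator decodes samples from
        the latent space, and the end-to-end VAE comes compiled."""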
        # Encoder layers
        Model_E0 = Conv1D(
            filters=self.CONV_filters,
            kernel_size=5,
            strides=1,
            padding="causal",
            activation="relu",
        )
        Model_E1 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E2 = Bidirectional(
            LSTM(
                self.LSTM_units_2,
                activation="tanh",
                return_sequences=False,
                kernel_constraint=UnitNorm(axis=0),
            )
        )
        Model_E3 = Dense(
            self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0)
        )
        Model_E4 = Dense(
            self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0)
        )
        Model_E5 = Dense(
            self.ENCODING,
            activation="relu",
            kernel_constraint=UnitNorm(axis=1),
            # Encourage uncorrelated features in the latent code
            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
        )

        # Decoder layers
        Model_D0 = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)
        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)
        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)
        Model_D3 = RepeatVector(self.input_shape[1])
        Model_D4 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="tanh",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )
        Model_D5 = Bidirectional(
            LSTM(
                self.LSTM_units_1,
                activation="sigmoid",
                return_sequences=True,
                kernel_constraint=UnitNorm(axis=1),
            )
        )

        # Define and instantiate encoder
        x = Input(shape=self.input_shape[1:])
        encoder = Model_E0(x)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E1(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E2(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E3(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Dropout(self.DROPOUT_RATE)(encoder)
        encoder = Model_E4(encoder)
        encoder = BatchNormalization()(encoder)
        encoder = Model_E5(encoder)

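        # Parameters of the approximate posterior q(z|x)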
        z_mean = Dense(self.ENCODING)(encoder)
        z_log_sigma = Dense(self.ENCODING)(encoder)

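        # Latent space regularization: KL divergence ("ELBO"), maximum mean
        # discrepancy ("MMD"), or both, depending on self.loss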
        if "ELBO" in self.loss:
            z_mean, z_log_sigma = KLDivergenceLayer()([z_mean, z_log_sigma])

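        # Reparameterization trick: sample z from N(z_mean, exp(z_log_sigma))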
        z = Lambda(sampling)([z_mean, z_log_sigma])

        if "MMD" in self.loss:
            z = MMDiscrepancyLayer()(z)

        # Define and instantiate generator
        generator = Model_D0(z)
        generator = BatchNormalization()(generator)
        generator = Model_D1(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D2(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D3(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D4(generator)
        generator = BatchNormalization()(generator)
        generator = Model_D5(generator)
        x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)

        # End-to-end autoencoder
        encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
        vae = Model(x, x_decoded_mean, name="SEQ_2_SEQ_VAE")

        # Build generator as a separate entity
        g = Input(shape=(self.ENCODING,))
        _generator = Model_D0(g)
        _generator = Model_D1(_generator)
        _generator = Model_D2(_generator)
        _generator = Model_D3(_generator)
        _generator = Model_D4(_generator)
        _generator = Model_D5(_generator)
        _x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
        generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

        def huber_loss(x, x_decoded_mean):
            # Scale the summed reconstruction loss by the sequence length
            huber = Huber(reduction="sum", delta=100.0)
            return self.input_shape[1] * huber(x, x_decoded_mean)

        vae.compile(
            loss=huber_loss,
            optimizer=Adam(learning_rate=self.learn_rate),
            metrics=["mae"],
            experimental_run_tf_function=False,
        )

        return encoder, generator, vae


class SEQ_2_SEQ_VAME:
    pass


class SEQ_2_SEQ_MMVAE:
    pass


# TODO:
#      - Baseline networks (done!)
#      - Initial Convnet switch (done!)
#      - Bidirectional LSTM switches (done!)
#      - Change LSTMs for GRU (done!)
#      - Tied/Untied weights (done!)
#      - Orthogonal/non-orthogonal weights (done!)
#      - Unit Norm constraint (done!)
#      - Add batch normalization
#      - Add He initialization

# TODO next:
#      - VAE loss function (though this should be analysed later on, taking the encodings into account)
#      - Smaller input sliding window (10-15 frames)
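

# Minimal usage sketch. The array shapes below are illustrative assumptions
# (windows of 10 frames with 20 features each), not values from the source:
if __name__ == "__main__":

    import numpy as np

    data = np.random.normal(size=(1024, 10, 20)).astype("float32")

    # Plain autoencoder: learn to reconstruct the input windows
    encoder, decoder, ae = SEQ_2_SEQ_AE(input_shape=data.shape).build()
    ae.fit(data, data, epochs=1, batch_size=128)

    # Variational autoencoder with both ELBO and MMD regularization
    encoder, generator, vae = SEQ_2_SEQ_VAE(
        input_shape=data.shape, loss="ELBO+MMD"
    ).build()
    vae.fit(data, data, epochs=1, batch_size=128)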