Commit 536408f1 authored by lucas_miranda's avatar lucas_miranda
Browse files

Instantiated S2SVAE in models.py, including encoder, generator and full VAE

parent f137ab3f
......@@ -135,7 +135,8 @@ class SEQ_2_SEQ_VAE:
DENSE_2,
DROPOUT_RATE,
ENCODING,
learn_rate,
learn_rate=1e-5,
loss="ELBO+MMD",
):
self.input_shape = input_shape
self.CONV_filters = CONV_filters
......@@ -146,11 +147,12 @@ class SEQ_2_SEQ_VAE:
self.DROPOUT_RATE = DROPOUT_RATE
self.ENCODING = ENCODING
self.learn_rate = learn_rate
self.loss = loss
def build(self):
# Encoder Layers
Model_E0 = tf.keras.layers.Conv1D(
filters=CONV_filters,
filters=self.CONV_filters,
kernel_size=5,
strides=1,
padding="causal",
......@@ -158,7 +160,7 @@ class SEQ_2_SEQ_VAE:
)
Model_E1 = Bidirectional(
LSTM(
LSTM_units_1,
self.LSTM_units_1,
activation="tanh",
return_sequences=True,
kernel_constraint=UnitNorm(axis=0),
......@@ -166,16 +168,20 @@ class SEQ_2_SEQ_VAE:
)
Model_E2 = Bidirectional(
LSTM(
LSTM_units_2,
self.LSTM_units_2,
activation="tanh",
return_sequences=False,
kernel_constraint=UnitNorm(axis=0),
)
)
Model_E3 = Dense(DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0))
Model_E4 = Dense(DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0))
Model_E3 = Dense(
self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0)
)
Model_E4 = Dense(
self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0)
)
Model_E5 = Dense(
ENCODING,
self.ENCODING,
activation="relu",
kernel_constraint=UnitNorm(axis=1),
activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
......@@ -184,7 +190,7 @@ class SEQ_2_SEQ_VAE:
# Decoder layers
Model_D4 = Bidirectional(
LSTM(
LSTM_units_1,
self.LSTM_units_1,
activation="tanh",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
......@@ -192,7 +198,7 @@ class SEQ_2_SEQ_VAE:
)
Model_D5 = Bidirectional(
LSTM(
LSTM_units_1,
self.LSTM_units_1,
activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
......@@ -205,12 +211,12 @@ class SEQ_2_SEQ_VAE:
encoder = Model_E1(encoder)
encoder = Model_E2(encoder)
encoder = Model_E3(encoder)
encoder = Dropout(DROPOUT_RATE)(encoder)
encoder = Dropout(self.DROPOUT_RATE)(encoder)
encoder = Model_E4(encoder)
encoder = Model_E5(encoder)
z_mean = Dense(ENCODING)(encoder)
z_log_sigma = Dense(ENCODING)(encoder)
z_mean = Dense(self.ENCODING)(encoder)
z_log_sigma = Dense(self.ENCODING)(encoder)
if "ELBO" in self.loss:
z_mean, z_log_sigma = KLDivergenceLayer()([z_mean, z_log_sigma])
......@@ -221,11 +227,11 @@ class SEQ_2_SEQ_VAE:
z = MMDiscrepancyLayer()(z)
        # Define and instantiate decoder
decoder = DenseTranspose(Model_E5, activation="relu", output_dim=ENCODING)(z)
decoder = DenseTranspose(Model_E4, activation="relu", output_dim=DENSE_2)(
decoder = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)(z)
decoder = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)(
decoder
)
decoder = DenseTranspose(Model_E3, activation="relu", output_dim=DENSE_1)(
decoder = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)(
decoder
)
decoder = RepeatVector(self.input_shape[1])(decoder)
......@@ -233,6 +239,21 @@ class SEQ_2_SEQ_VAE:
decoder = Model_D5(decoder)
x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(decoder)
#generator
_decoder_input = Input(shape=(self.ENCODING,))
_decoder = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)(z)
_decoder = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)(
decoder
)
_decoder = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)(
decoder
)
_decoder = RepeatVector(self.input_shape[1])(decoder)
_decoder = Model_D4(decoder)
_decoder = Model_D5(decoder)
_x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(decoder)
generator = Model(_decoder_input, _x_decoded_mean)
# end-to-end autoencoder
vae = Model(x, x_decoded_mean)
......@@ -243,13 +264,7 @@ class SEQ_2_SEQ_VAE:
vae.compile(
loss=huber_loss,
optimizer=Adam(
lr=hp.Float(
"learning_rate",
min_value=1e-4,
max_value=1e-2,
sampling="LOG",
default=1e-3,
),
lr=self.learn_rate,
),
metrics=["mae"],
experimental_run_tf_function=False,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment