Commit 8549f828 authored by lucas_miranda's avatar lucas_miranda
Browse files

Drafted skeleton of SEQ2SEQ_VAEP (Variational AutoEncoding Predictor) in models.py

parent 8ec34ddd
This diff is collapsed.
This diff is collapsed.
......@@ -317,16 +317,16 @@ class SEQ_2_SEQ_VAE:
class SEQ_2_SEQ_VAEP:
def __init__(
self,
input_shape,
CONV_filters=256,
LSTM_units_1=256,
LSTM_units_2=64,
DENSE_2=64,
DROPOUT_RATE=0.25,
ENCODING=32,
learn_rate=1e-3,
loss="ELBO+MMD",
self,
input_shape,
CONV_filters=256,
LSTM_units_1=256,
LSTM_units_2=64,
DENSE_2=64,
DROPOUT_RATE=0.25,
ENCODING=32,
learn_rate=1e-3,
loss="ELBO+MMD",
):
self.input_shape = input_shape
self.CONV_filters = CONV_filters
......@@ -386,12 +386,16 @@ class SEQ_2_SEQ_VAEP:
)
# Decoder layers
Model_B1 = BatchNormalization()
Model_B2 = BatchNormalization()
Model_B3 = BatchNormalization()
Model_B4 = BatchNormalization()
Model_B5 = BatchNormalization()
Model_D0 = DenseTranspose(
Model_E5, activation="relu", output_dim=self.ENCODING,
)
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2, )
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1, )
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
Model_D3 = RepeatVector(self.input_shape[1])
Model_D4 = Bidirectional(
LSTM(
......@@ -438,34 +442,64 @@ class SEQ_2_SEQ_VAEP:
# Define and instanciate generator
generator = Model_D0(z)
generator = BatchNormalization()(generator)
generator = Model_B1(generator)
generator = Model_D1(generator)
generator = BatchNormalization()(generator)
generator = Model_B2(generator)
generator = Model_D2(generator)
generator = BatchNormalization()(generator)
generator = Model_B3(generator)
generator = Model_D3(generator)
generator = BatchNormalization()(generator)
generator = Model_B4(generator)
generator = Model_D4(generator)
generator = BatchNormalization()(generator)
generator = Model_B5(generator)
generator = Model_D5(generator)
x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)
# Define and instanciate predictor
predictor = Model_D0(z)
predictor = Model_B1(predictor)
predictor = Model_D1(predictor)
predictor = Model_B2(predictor)
predictor = Model_D2(predictor)
predictor = Model_B3(predictor)
predictor = Model_D3(predictor)
predictor = Model_B4(predictor)
predictor = Bidirectional(
LSTM(
self.LSTM_units_1,
activation="tanh",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
)
)(predictor)
predictor = BatchNormalization()(predictor)
predictor = Bidirectional(
LSTM(
self.LSTM_units_1,
activation="sigmoid",
return_sequences=True,
kernel_constraint=UnitNorm(axis=1),
)
)(predictor)
x_predicted_mean = TimeDistributed(Dense(self.input_shape[2]))(predictor)
# end-to-end autoencoder
encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
vaep = Model(x, x_decoded_mean, name="SEQ_2_SEQ_VAE")
vaep = Model(
inputs=x, outputs=[x_decoded_mean, x_predicted_mean], name="SEQ_2_SEQ_VAE"
)
# Build generator as a separate entity
g = Input(shape=self.ENCODING)
_generator = Model_D0(g)
_generator = BatchNormalization()(_generator)
_generator = Model_B1(_generator)
_generator = Model_D1(_generator)
_generator = BatchNormalization()(_generator)
_generator = Model_B2(_generator)
_generator = Model_D2(_generator)
_generator = BatchNormalization()(_generator)
_generator = Model_B3(_generator)
_generator = Model_D3(_generator)
_generator = BatchNormalization()(_generator)
_generator = Model_B4(_generator)
_generator = Model_D4(_generator)
_generator = BatchNormalization()(_generator)
_generator = Model_B5(_generator)
_generator = Model_D5(_generator)
_x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
......@@ -476,7 +510,7 @@ class SEQ_2_SEQ_VAEP:
vaep.compile(
loss=huber_loss,
optimizer=Adam(lr=self.learn_rate, ),
optimizer=Adam(lr=self.learn_rate,),
metrics=["mae"],
experimental_run_tf_function=False,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment