Commit eff66fa2 authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored VAE model

parent 536408f1
......@@ -15,13 +15,13 @@ class SEQ_2_SEQ_AE:
def __init__(
self,
input_shape,
CONV_filters,
LSTM_units_1,
LSTM_units_2,
DENSE_2,
DROPOUT_RATE,
ENCODING,
learn_rate,
CONV_filters=256,
LSTM_units_1=256,
LSTM_units_2=64,
DENSE_2=64,
DROPOUT_RATE=0.25,
ENCODING=32,
learn_rate=1e-3,
):
self.input_shape = input_shape
self.CONV_filters = CONV_filters
......@@ -91,7 +91,7 @@ class SEQ_2_SEQ_AE:
)
# Define and instanciate encoder
encoder = Sequential(name="DLC_encoder")
encoder = Sequential(name="SEQ_2_SEQ_Encoder")
encoder.add(Model_E0)
encoder.add(Model_E1)
encoder.add(Model_E2)
......@@ -101,7 +101,7 @@ class SEQ_2_SEQ_AE:
encoder.add(Model_E5)
# Define and instanciate decoder
decoder = Sequential(name="DLC_Decoder")
decoder = Sequential(name="SEQ_2_SEQ_Decoder")
decoder.add(
DenseTranspose(
Model_E5, activation="relu", input_shape=(self.ENCODING,), output_dim=64
......@@ -114,7 +114,7 @@ class SEQ_2_SEQ_AE:
decoder.add(Model_D5)
decoder.add(TimeDistributed(Dense(self.input_shape[2])))
model = Sequential([encoder, decoder], name="DLC_Autoencoder")
model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")
model.compile(
loss=Huber(reduction="sum", delta=100.0),
......@@ -122,20 +122,20 @@ class SEQ_2_SEQ_AE:
metrics=["mae"],
)
return model
return encoder, decoder, model
class SEQ_2_SEQ_VAE:
def __init__(
self,
input_shape,
CONV_filters,
LSTM_units_1,
LSTM_units_2,
DENSE_2,
DROPOUT_RATE,
ENCODING,
learn_rate=1e-5,
CONV_filters=256,
LSTM_units_1=256,
LSTM_units_2=64,
DENSE_2=64,
DROPOUT_RATE=0.25,
ENCODING=32,
learn_rate=1e-3,
loss="ELBO+MMD",
):
self.input_shape = input_shape
......@@ -188,6 +188,11 @@ class SEQ_2_SEQ_VAE:
)
# Decoder layers
Model_D0 = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)
Model_D3 = RepeatVector(self.input_shape[1])
Model_D4 = Bidirectional(
LSTM(
self.LSTM_units_1,
......@@ -226,36 +231,21 @@ class SEQ_2_SEQ_VAE:
if "MMD" in self.loss:
z = MMDiscrepancyLayer()(z)
# Define and instanciate decoder
decoder = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)(z)
decoder = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)(
decoder
)
decoder = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)(
decoder
)
decoder = RepeatVector(self.input_shape[1])(decoder)
decoder = Model_D4(decoder)
decoder = Model_D5(decoder)
x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(decoder)
#generator
_decoder_input = Input(shape=(self.ENCODING,))
_decoder = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)(z)
_decoder = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)(
decoder
)
_decoder = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)(
decoder
)
_decoder = RepeatVector(self.input_shape[1])(decoder)
_decoder = Model_D4(decoder)
_decoder = Model_D5(decoder)
_x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(decoder)
generator = Model(_decoder_input, _x_decoded_mean)
# Define and instanciate generator
generator = Model_D0(z)
generator = Model_D1(generator)
generator = Model_D2(generator)
generator = Model_D3(generator)
generator = Model_D4(generator)
generator = Model_D5(generator)
x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)
# end-to-end autoencoder
vae = Model(x, x_decoded_mean)
encoder = Model(x, z_mean, name="SEQ_2_SEQ_VEncoder")
vae = Model(x, x_decoded_mean, name="SEQ_2_SEQ_VAE")
#g = Input(shape=self.ENCODING)
#generator = Model(g, x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
def huber_loss(x, x_decoded_mean):
huber_loss = Huber(reduction="sum", delta=100.0)
......@@ -263,14 +253,12 @@ class SEQ_2_SEQ_VAE:
vae.compile(
loss=huber_loss,
optimizer=Adam(
lr=self.learn_rate,
),
optimizer=Adam(lr=self.learn_rate,),
metrics=["mae"],
experimental_run_tf_function=False,
)
return encoder, generator, vae
return encoder, vae
class SEQ_2_SEQ_MVAE:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment