Commit 75d88e72 authored by lucas_miranda's avatar lucas_miranda
Browse files

Drafted skeleton of SEQ2SEQ_VAEP (Variational AutoEncoding Predictor) in models.py

parent 32a2aadb
......@@ -222,6 +222,7 @@ class SEQ_2_SEQ_VAE:
Model_B3 = BatchNormalization()
Model_B4 = BatchNormalization()
Model_B5 = BatchNormalization()
Model_B6 = BatchNormalization()
Model_D0 = DenseTranspose(
Model_E5, activation="relu", output_dim=self.ENCODING,
)
......@@ -283,6 +284,7 @@ class SEQ_2_SEQ_VAE:
generator = Model_D4(generator)
generator = Model_B5(generator)
generator = Model_D5(generator)
generator = Model_B6(generator)
x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)
# end-to-end autoencoder
......@@ -302,6 +304,7 @@ class SEQ_2_SEQ_VAE:
_generator = Model_D4(_generator)
_generator = Model_B5(_generator)
_generator = Model_D5(_generator)
_generator = Model_B6(_generator)
_x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
......@@ -395,6 +398,7 @@ class SEQ_2_SEQ_VAEP:
Model_B3 = BatchNormalization()
Model_B4 = BatchNormalization()
Model_B5 = BatchNormalization()
Model_B6 = BatchNormalization()
Model_D0 = DenseTranspose(
Model_E5, activation="relu", output_dim=self.ENCODING,
)
......@@ -456,6 +460,7 @@ class SEQ_2_SEQ_VAEP:
generator = Model_D4(generator)
generator = Model_B5(generator)
generator = Model_D5(generator)
generator = Model_B6(generator)
x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(generator)
# Define and instanciate predictor
......@@ -484,6 +489,7 @@ class SEQ_2_SEQ_VAEP:
kernel_constraint=UnitNorm(axis=1),
)
)(predictor)
predictor = BatchNormalization()(predictor)
x_predicted_mean = TimeDistributed(Dense(self.input_shape[2]))(predictor)
# end-to-end autoencoder
......@@ -505,6 +511,7 @@ class SEQ_2_SEQ_VAEP:
_generator = Model_D4(_generator)
_generator = Model_B5(_generator)
_generator = Model_D5(_generator)
_generator = Model_B6(_generator)
_x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(_generator)
generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment