Commit 1a4e8cdc authored by lucas_miranda's avatar lucas_miranda
Browse files

Added a Conv1D layer at the end of both decoder and next_sequence_predictor

parent c25f8d92
......@@ -341,7 +341,7 @@ class SEQ_2_SEQ_GMVAE:
"bidirectional_merge": "concat",
"clipvalue": 1.0,
"dense_activation": "relu",
"dense_layers_per_branch": 1,
"dense_layers_per_branch": 3,
"dropout_rate": 0.05,
"learning_rate": 1e-3,
"units_conv": 64,
......@@ -362,7 +362,7 @@ class SEQ_2_SEQ_GMVAE:
filters=self.CONV_filters,
kernel_size=5,
strides=1,
padding="causal",
padding="same",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
......@@ -418,6 +418,8 @@ class SEQ_2_SEQ_GMVAE:
Model_B1 = BatchNormalization()
Model_B2 = BatchNormalization()
Model_B3 = BatchNormalization()
Model_B4 = BatchNormalization()
seq_D = [
Dense(
self.DENSE_2,
......@@ -463,6 +465,15 @@ class SEQ_2_SEQ_GMVAE:
),
merge_mode=self.bidirectional_merge,
)
Model_D6 = tf.keras.layers.Conv1D(
filters=self.CONV_filters,
kernel_size=5,
strides=1,
padding="same",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
# Predictor layers
Model_P1 = Dense(
......@@ -519,11 +530,13 @@ class SEQ_2_SEQ_GMVAE:
Model_B1,
Model_B2,
Model_B3,
Model_B4,
Model_D1,
Model_D2,
Model_D3,
Model_D4,
Model_D5,
Model_D6,
Model_P1,
Model_P2,
Model_P3,
......@@ -547,11 +560,13 @@ class SEQ_2_SEQ_GMVAE:
Model_B1,
Model_B2,
Model_B3,
Model_B4,
Model_D1,
Model_D2,
Model_D3,
Model_D4,
Model_D5,
Model_D6,
Model_P1,
Model_P2,
Model_P3,
......@@ -577,7 +592,7 @@ class SEQ_2_SEQ_GMVAE:
self.number_of_components,
name="cluster_assignment",
activation="softmax",
kernel_regularizer=(
activity_regularizer=(
tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
if self.reg_cat_clusters
else None
......@@ -688,6 +703,8 @@ class SEQ_2_SEQ_GMVAE:
generator = Model_B2(generator)
generator = Model_D5(generator)
generator = Model_B3(generator)
generator = Model_D6(generator)
generator = Model_B4(generator)
x_decoded_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(generator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment