Commit 1a4e8cdc authored by lucas_miranda

Added a Conv1D layer at the end of both decoder and next_sequence_predictor

parent c25f8d92
@@ -341,7 +341,7 @@ class SEQ_2_SEQ_GMVAE:
             "bidirectional_merge": "concat",
             "clipvalue": 1.0,
             "dense_activation": "relu",
-            "dense_layers_per_branch": 1,
+            "dense_layers_per_branch": 3,
             "dropout_rate": 0.05,
             "learning_rate": 1e-3,
             "units_conv": 64,
@@ -362,7 +362,7 @@ class SEQ_2_SEQ_GMVAE:
             filters=self.CONV_filters,
             kernel_size=5,
             strides=1,
-            padding="causal",
+            padding="same",
             activation=self.dense_activation,
             kernel_initializer=he_uniform(),
             use_bias=True,
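Note on the hunk above: it switches an existing Conv1D from causal to same padding. Both modes preserve the sequence length, but causal pads only on the left so each output step depends on the current and earlier frames, while same pads symmetrically so each step also sees nearby future frames. A quick toy check (not part of the commit, toy shapes):

import numpy as np
import tensorflow as tf

x = np.ones((1, 10, 3), dtype="float32")  # (batch, time, features) toy input

causal = tf.keras.layers.Conv1D(8, kernel_size=5, padding="causal")(x)
same = tf.keras.layers.Conv1D(8, kernel_size=5, padding="same")(x)

print(causal.shape, same.shape)  # both (1, 10, 8): length is preserved in either mode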
@@ -418,6 +418,8 @@ class SEQ_2_SEQ_GMVAE:
         Model_B1 = BatchNormalization()
         Model_B2 = BatchNormalization()
         Model_B3 = BatchNormalization()
+        Model_B4 = BatchNormalization()
         seq_D = [
             Dense(
                 self.DENSE_2,
@@ -463,6 +465,15 @@ class SEQ_2_SEQ_GMVAE:
             ),
             merge_mode=self.bidirectional_merge,
         )
+        Model_D6 = tf.keras.layers.Conv1D(
+            filters=self.CONV_filters,
+            kernel_size=5,
+            strides=1,
+            padding="same",
+            activation=self.dense_activation,
+            kernel_initializer=he_uniform(),
+            use_bias=True,
+        )
         # Predictor layers
         Model_P1 = Dense(
@@ -519,11 +530,13 @@ class SEQ_2_SEQ_GMVAE:
             Model_B1,
             Model_B2,
             Model_B3,
+            Model_B4,
             Model_D1,
             Model_D2,
             Model_D3,
             Model_D4,
             Model_D5,
+            Model_D6,
             Model_P1,
             Model_P2,
             Model_P3,
@@ -547,11 +560,13 @@ class SEQ_2_SEQ_GMVAE:
             Model_B1,
             Model_B2,
             Model_B3,
+            Model_B4,
             Model_D1,
             Model_D2,
             Model_D3,
             Model_D4,
             Model_D5,
+            Model_D6,
             Model_P1,
             Model_P2,
             Model_P3,
@@ -577,7 +592,7 @@ class SEQ_2_SEQ_GMVAE:
             self.number_of_components,
             name="cluster_assignment",
             activation="softmax",
-            kernel_regularizer=(
+            activity_regularizer=(
                 tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
                 if self.reg_cat_clusters
                 else None
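Note on the hunk above: moving the L1/L2 penalty from kernel_regularizer to activity_regularizer on the cluster_assignment layer means the penalty is now computed on the layer's softmax output (the cluster-assignment probabilities) rather than on its weight matrix, so the assignments themselves are pushed toward sparsity when reg_cat_clusters is enabled. A minimal sketch (hypothetical component count, not code from this commit):

import tensorflow as tf

k = 15  # hypothetical number of mixture components
cluster_assignment = tf.keras.layers.Dense(
    k,
    name="cluster_assignment",
    activation="softmax",
    # penalty applied to the layer's output (activations), not to its kernel
    activity_regularizer=tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01),
)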
@@ -688,6 +703,8 @@ class SEQ_2_SEQ_GMVAE:
         generator = Model_B2(generator)
         generator = Model_D5(generator)
         generator = Model_B3(generator)
+        generator = Model_D6(generator)
+        generator = Model_B4(generator)
         x_decoded_mean = Dense(
             tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
         )(generator)
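Taken together, the hunks above append a same-padded Conv1D (Model_D6) followed by a new BatchNormalization (Model_B4) to the end of the generator, just before the Dense layer that parameterizes the output distribution; per the commit message, the next_sequence_predictor gets the same treatment, though that hunk is not visible here. A minimal standalone sketch of the new decoder tail (assumed shapes and feature counts; layer roles mirror Model_D6 / Model_B4 from the diff, but this is not the full model):

import tensorflow as tf

conv_filters = 64   # assumed, matches the "units_conv" default above
features = 12       # hypothetical number of output features per frame

decoded = tf.keras.Input(shape=(None, 128))  # output of the last decoder block (assumed width)
x = tf.keras.layers.Conv1D(
    filters=conv_filters,
    kernel_size=5,
    strides=1,
    padding="same",
    activation="relu",
    kernel_initializer="he_uniform",
)(decoded)                                   # plays the role of Model_D6
x = tf.keras.layers.BatchNormalization()(x)  # plays the role of Model_B4
# In the diff, this feeds the Dense layer sized by
# tfpl.IndependentNormal.params_size(...) // 2; a plain Dense stands in here.
x = tf.keras.layers.Dense(features)(x)
tail = tf.keras.Model(decoded, x)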