From 1a4e8cdc9911e5eb3c4be37860767946bafac40c Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Thu, 22 Apr 2021 16:03:20 +0200 Subject: [PATCH] Added a Conv1D layer at the end of both decoder and next_sequence_predictor --- deepof/models.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/deepof/models.py b/deepof/models.py index b4481348..2756faaf 100644 --- a/deepof/models.py +++ b/deepof/models.py @@ -341,7 +341,7 @@ class SEQ_2_SEQ_GMVAE: "bidirectional_merge": "concat", "clipvalue": 1.0, "dense_activation": "relu", - "dense_layers_per_branch": 1, + "dense_layers_per_branch": 3, "dropout_rate": 0.05, "learning_rate": 1e-3, "units_conv": 64, @@ -362,7 +362,7 @@ class SEQ_2_SEQ_GMVAE: filters=self.CONV_filters, kernel_size=5, strides=1, - padding="causal", + padding="same", activation=self.dense_activation, kernel_initializer=he_uniform(), use_bias=True, @@ -418,6 +418,8 @@ class SEQ_2_SEQ_GMVAE: Model_B1 = BatchNormalization() Model_B2 = BatchNormalization() Model_B3 = BatchNormalization() + Model_B4 = BatchNormalization() + seq_D = [ Dense( self.DENSE_2, @@ -463,6 +465,15 @@ class SEQ_2_SEQ_GMVAE: ), merge_mode=self.bidirectional_merge, ) + Model_D6 = tf.keras.layers.Conv1D( + filters=self.CONV_filters, + kernel_size=5, + strides=1, + padding="same", + activation=self.dense_activation, + kernel_initializer=he_uniform(), + use_bias=True, + ) # Predictor layers Model_P1 = Dense( @@ -519,11 +530,13 @@ class SEQ_2_SEQ_GMVAE: Model_B1, Model_B2, Model_B3, + Model_B4, Model_D1, Model_D2, Model_D3, Model_D4, Model_D5, + Model_D6, Model_P1, Model_P2, Model_P3, @@ -547,11 +560,13 @@ class SEQ_2_SEQ_GMVAE: Model_B1, Model_B2, Model_B3, + Model_B4, Model_D1, Model_D2, Model_D3, Model_D4, Model_D5, + Model_D6, Model_P1, Model_P2, Model_P3, @@ -577,7 +592,7 @@ class SEQ_2_SEQ_GMVAE: self.number_of_components, name="cluster_assignment", activation="softmax", - kernel_regularizer=( + activity_regularizer=( tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01) if self.reg_cat_clusters else None @@ -688,6 +703,8 @@ class SEQ_2_SEQ_GMVAE: generator = Model_B2(generator) generator = Model_D5(generator) generator = Model_B3(generator) + generator = Model_D6(generator) + generator = Model_B4(generator) x_decoded_mean = Dense( tfpl.IndependentNormal.params_size(input_shape[2:]) // 2 )(generator) -- GitLab