From 1a4e8cdc9911e5eb3c4be37860767946bafac40c Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Thu, 22 Apr 2021 16:03:20 +0200
Subject: [PATCH] Added a Conv1D layer at the end of both decoder and
next_sequence_predictor
---
deepof/models.py | 23 ++++++++++++++++++++---
1 file changed, 20 insertions(+), 3 deletions(-)
diff --git a/deepof/models.py b/deepof/models.py
index b4481348..2756faaf 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -341,7 +341,7 @@ class SEQ_2_SEQ_GMVAE:
"bidirectional_merge": "concat",
"clipvalue": 1.0,
"dense_activation": "relu",
- "dense_layers_per_branch": 1,
+ "dense_layers_per_branch": 3,
"dropout_rate": 0.05,
"learning_rate": 1e-3,
"units_conv": 64,
@@ -362,7 +362,7 @@ class SEQ_2_SEQ_GMVAE:
filters=self.CONV_filters,
kernel_size=5,
strides=1,
- padding="causal",
+ padding="same",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
@@ -418,6 +418,8 @@ class SEQ_2_SEQ_GMVAE:
Model_B1 = BatchNormalization()
Model_B2 = BatchNormalization()
Model_B3 = BatchNormalization()
+ Model_B4 = BatchNormalization()
+
seq_D = [
Dense(
self.DENSE_2,
@@ -463,6 +465,15 @@ class SEQ_2_SEQ_GMVAE:
),
merge_mode=self.bidirectional_merge,
)
+ Model_D6 = tf.keras.layers.Conv1D(
+ filters=self.CONV_filters,
+ kernel_size=5,
+ strides=1,
+ padding="same",
+ activation=self.dense_activation,
+ kernel_initializer=he_uniform(),
+ use_bias=True,
+ )
# Predictor layers
Model_P1 = Dense(
@@ -519,11 +530,13 @@ class SEQ_2_SEQ_GMVAE:
Model_B1,
Model_B2,
Model_B3,
+ Model_B4,
Model_D1,
Model_D2,
Model_D3,
Model_D4,
Model_D5,
+ Model_D6,
Model_P1,
Model_P2,
Model_P3,
@@ -547,11 +560,13 @@ class SEQ_2_SEQ_GMVAE:
Model_B1,
Model_B2,
Model_B3,
+ Model_B4,
Model_D1,
Model_D2,
Model_D3,
Model_D4,
Model_D5,
+ Model_D6,
Model_P1,
Model_P2,
Model_P3,
@@ -577,7 +592,7 @@ class SEQ_2_SEQ_GMVAE:
self.number_of_components,
name="cluster_assignment",
activation="softmax",
- kernel_regularizer=(
+ activity_regularizer=(
tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
if self.reg_cat_clusters
else None
@@ -688,6 +703,8 @@ class SEQ_2_SEQ_GMVAE:
generator = Model_B2(generator)
generator = Model_D5(generator)
generator = Model_B3(generator)
+ generator = Model_D6(generator)
+ generator = Model_B4(generator)
x_decoded_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(generator)
--
GitLab