Commit e7514b68 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added a Conv1D layer at the end of both decoder and next_sequence_predictor

parent 1a4e8cdc
Pipeline #100109 passed with stages
in 21 minutes and 55 seconds
......@@ -415,11 +415,6 @@ class SEQ_2_SEQ_GMVAE:
Model_E4.append(BatchNormalization())
# Decoder layers
Model_B1 = BatchNormalization()
Model_B2 = BatchNormalization()
Model_B3 = BatchNormalization()
Model_B4 = BatchNormalization()
seq_D = [
Dense(
self.DENSE_2,
......@@ -506,6 +501,15 @@ class SEQ_2_SEQ_GMVAE:
),
merge_mode=self.bidirectional_merge,
)
Model_P4 = tf.keras.layers.Conv1D(
filters=self.CONV_filters,
kernel_size=5,
strides=1,
padding="same",
activation=self.dense_activation,
kernel_initializer=he_uniform(),
use_bias=True,
)
# Phenotype classification layer
Model_PC1 = Dense(
......@@ -527,10 +531,6 @@ class SEQ_2_SEQ_GMVAE:
Model_E2,
Model_E3,
Model_E4,
Model_B1,
Model_B2,
Model_B3,
Model_B4,
Model_D1,
Model_D2,
Model_D3,
......@@ -540,6 +540,7 @@ class SEQ_2_SEQ_GMVAE:
Model_P1,
Model_P2,
Model_P3,
Model_P4,
Model_PC1,
Model_RC1,
)
......@@ -557,10 +558,6 @@ class SEQ_2_SEQ_GMVAE:
Model_E2,
Model_E3,
Model_E4,
Model_B1,
Model_B2,
Model_B3,
Model_B4,
Model_D1,
Model_D2,
Model_D3,
......@@ -570,6 +567,7 @@ class SEQ_2_SEQ_GMVAE:
Model_P1,
Model_P2,
Model_P3,
Model_P4,
Model_PC1,
Model_RC1,
) = self.get_layers(input_shape)
......@@ -697,14 +695,14 @@ class SEQ_2_SEQ_GMVAE:
g = Input(shape=self.ENCODING)
generator = Sequential(Model_D1)(g)
generator = Model_D2(generator)
generator = Model_B1(generator)
generator = BatchNormalization()(generator)
generator = Model_D3(generator)
generator = Model_D4(generator)
generator = Model_B2(generator)
generator = BatchNormalization()(generator)
generator = Model_D5(generator)
generator = Model_B3(generator)
generator = BatchNormalization()(generator)
generator = Model_D6(generator)
generator = Model_B4(generator)
generator = BatchNormalization()(generator)
x_decoded_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(generator)
......@@ -750,6 +748,7 @@ class SEQ_2_SEQ_GMVAE:
predictor = BatchNormalization()(predictor)
predictor = Model_P3(predictor)
predictor = BatchNormalization()(predictor)
predictor = Model_P4(predictor)
x_predicted_mean = Dense(
tfpl.IndependentNormal.params_size(input_shape[2:]) // 2
)(predictor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment