Commit aed60bd3 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added He initialization to Conv1D and Dense layers in models.py

parent 02ca2b4c
......@@ -2,6 +2,7 @@
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform
from tensorflow.keras.layers import BatchNormalization, Bidirectional, Dense
from tensorflow.keras.layers import Dropout, Lambda, LSTM
from tensorflow.keras.layers import RepeatVector, TimeDistributed
......@@ -41,6 +42,7 @@ class SEQ_2_SEQ_AE:
strides=1,
padding="causal",
activation="relu",
kernel_initializer=he_uniform(),
)
Model_E1 = Bidirectional(
LSTM(
......@@ -59,22 +61,44 @@ class SEQ_2_SEQ_AE:
)
)
Model_E3 = Dense(
self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0)
self.DENSE_1,
activation="relu",
kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
)
Model_E4 = Dense(
self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0)
self.DENSE_2,
activation="relu",
kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
)
Model_E5 = Dense(
self.ENCODING,
activation="relu",
kernel_constraint=UnitNorm(axis=1),
activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
kernel_initializer=he_uniform(),
)
# Decoder layers
Model_D0 = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)
Model_D0 = DenseTranspose(
Model_E5,
activation="relu",
output_dim=self.ENCODING,
kernel_initializer=he_uniform(),
)
Model_D1 = DenseTranspose(
Model_E4,
activation="relu",
output_dim=self.DENSE_2,
kernel_initializer=he_uniform(),
)
Model_D2 = DenseTranspose(
Model_E3,
activation="relu",
output_dim=self.DENSE_1,
kernel_initializer=he_uniform(),
)
Model_D3 = RepeatVector(self.input_shape[1])
Model_D4 = Bidirectional(
LSTM(
......@@ -167,6 +191,7 @@ class SEQ_2_SEQ_VAE:
strides=1,
padding="causal",
activation="relu",
kernel_initializer=he_uniform(),
)
Model_E1 = Bidirectional(
LSTM(
......@@ -185,23 +210,45 @@ class SEQ_2_SEQ_VAE:
)
)
Model_E3 = Dense(
self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0)
self.DENSE_1,
activation="relu",
kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
)
Model_E4 = Dense(
self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0)
self.DENSE_2,
activation="relu",
kernel_constraint=UnitNorm(axis=0),
kernel_initializer=he_uniform(),
)
Model_E5 = Dense(
self.ENCODING,
activation="relu",
kernel_constraint=UnitNorm(axis=1),
activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
kernel_initializer=he_uniform(),
)
# Decoder layers
Model_D0 = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)
Model_D0 = DenseTranspose(
Model_E5,
activation="relu",
output_dim=self.ENCODING,
kernel_initializer=he_uniform(),
)
Model_D1 = DenseTranspose(
Model_E4,
activation="relu",
output_dim=self.DENSE_2,
kernel_initializer=he_uniform(),
)
Model_D2 = DenseTranspose(
Model_E3,
activation="relu",
output_dim=self.DENSE_1,
kernel_initializer=he_uniform(),
)
Model_D3 = RepeatVector(self.input_shape[1])
Model_D4 = Bidirectional(
LSTM(
......@@ -305,8 +352,12 @@ class SEQ_2_SEQ_MMVAE:
# - Tied/Untied weights (done!)
# - orthogonal/non-orthogonal weights (done!)
# - Unit Norm constraint (done!)
# - add batch normalization
# - add batch normalization (done!)
# - add He initialization
# - remove sigmoid activation from last layer
# - add another dropout
# - try orthonormal initialization in encoding layer
# - try reverse sequence as output!
# TODO next:
# - VAE loss function (though this should be analysed later on taking the encodings into account)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment