Skip to content
Snippets Groups Projects
Commit 41670663 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Added a hyperparameter controlling the number of layers near the latent space

parent 13c2a34e
Branches
Tags
No related merge requests found
Pipeline #87830 failed
...@@ -107,11 +107,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -107,11 +107,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
# Architectural hyperparameters # Architectural hyperparameters
clipvalue = hp.Float( clipvalue = hp.Float(
"clipvalue", "clipvalue", min_value=0.0, max_value=1.0, default=0.5, sampling="Linear"
min_value=0.0,
max_value=1.0,
default=0.5,
sampling="Linear"
) )
conv_filters = hp.Int( conv_filters = hp.Int(
"units_conv", min_value=128, max_value=160, step=16, default=128, "units_conv", min_value=128, max_value=160, step=16, default=128,
...@@ -120,6 +116,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -120,6 +116,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
"units_dense2", min_value=120, max_value=180, step=10, default=150, "units_dense2", min_value=120, max_value=180, step=10, default=150,
) )
dense_activation = hp.Choice("dense_activation", values=["elu", "relu"]) dense_activation = hp.Choice("dense_activation", values=["elu", "relu"])
dense_layers_per_branch = hp.Int("dense_layers_per_branch", min_value=1, max_value=3, default=1)
dropout_rate = hp.Float( dropout_rate = hp.Float(
"dropout_rate", "dropout_rate",
min_value=0.0, min_value=0.0,
...@@ -138,6 +135,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -138,6 +135,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
conv_filters, conv_filters,
dense_2, dense_2,
dense_activation, dense_activation,
dense_layers_per_branch,
dropout_rate, dropout_rate,
encoding, encoding,
k, k,
...@@ -153,6 +151,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -153,6 +151,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
conv_filters, conv_filters,
dense_2, dense_2,
dense_activation, dense_activation,
dense_layers_per_branch,
dropout_rate, dropout_rate,
encoding, encoding,
k, k,
...@@ -164,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel): ...@@ -164,6 +163,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
"clipvalue": clipvalue, "clipvalue": clipvalue,
"dense_activation": dense_activation, "dense_activation": dense_activation,
"dropout_rate": dropout_rate, "dropout_rate": dropout_rate,
"dense_layers_per_branch": dense_layers_per_branch,
"encoding": encoding, "encoding": encoding,
"units_conv": conv_filters, "units_conv": conv_filters,
"units_dense_2": dense_2, "units_dense_2": dense_2,
......
...@@ -253,6 +253,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -253,6 +253,7 @@ class SEQ_2_SEQ_GMVAE:
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2) self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.clipvalue = self.hparams["clipvalue"] self.clipvalue = self.hparams["clipvalue"]
self.dense_activation = self.hparams["dense_activation"] self.dense_activation = self.hparams["dense_activation"]
self.dense_layers_per_branch = self.hparams["dense_layers_per_branch"]
self.learn_rate = self.hparams["learning_rate"] self.learn_rate = self.hparams["learning_rate"]
self.lstm_unroll = True self.lstm_unroll = True
self.compile = compile_model self.compile = compile_model
...@@ -313,14 +314,15 @@ class SEQ_2_SEQ_GMVAE: ...@@ -313,14 +314,15 @@ class SEQ_2_SEQ_GMVAE:
"""Sets the default parameters for the model. Overwritable with a dictionary""" """Sets the default parameters for the model. Overwritable with a dictionary"""
defaults = { defaults = {
"units_conv": 256, "clipvalue": 0.5,
"units_lstm": 256, "dense_activation": "elu",
"units_dense2": 64, "dense_layers_per_branch": 1,
"dropout_rate": 0.15, "dropout_rate": 0.15,
"encoding": 16, "encoding": 16,
"learning_rate": 1e-3, "learning_rate": 1e-3,
"clipvalue": 0.5, "units_conv": 256,
"dense_activation": "elu", "units_dense2": 64,
"units_lstm": 256,
} }
for k, v in params.items(): for k, v in params.items():
...@@ -370,25 +372,32 @@ class SEQ_2_SEQ_GMVAE: ...@@ -370,25 +372,32 @@ class SEQ_2_SEQ_GMVAE:
kernel_initializer=he_uniform(), kernel_initializer=he_uniform(),
use_bias=True, use_bias=True,
) )
Model_E4 = Dense(
self.DENSE_2, Model_E4 = [
activation=self.dense_activation, Dense(
# kernel_constraint=UnitNorm(axis=0), self.DENSE_2,
kernel_initializer=he_uniform(), activation=self.dense_activation,
use_bias=True, # kernel_constraint=UnitNorm(axis=0),
) kernel_initializer=he_uniform(),
use_bias=True,
)
for _ in self.dense_layers_per_branch
]
# Decoder layers # Decoder layers
Model_B1 = BatchNormalization() Model_B1 = BatchNormalization()
Model_B2 = BatchNormalization() Model_B2 = BatchNormalization()
Model_B3 = BatchNormalization() Model_B3 = BatchNormalization()
Model_B4 = BatchNormalization() Model_B4 = BatchNormalization()
Model_D1 = Dense( Model_D1 = [
self.DENSE_2, Dense(
activation=self.dense_activation, self.DENSE_2,
kernel_initializer=he_uniform(), activation=self.dense_activation,
use_bias=True, kernel_initializer=he_uniform(),
) use_bias=True,
)
for _ in self.dense_layers_per_branch
]
Model_D2 = Dense( Model_D2 = Dense(
self.DENSE_1, self.DENSE_1,
activation=self.dense_activation, activation=self.dense_activation,
...@@ -516,7 +525,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -516,7 +525,7 @@ class SEQ_2_SEQ_GMVAE:
encoder = Model_E3(encoder) encoder = Model_E3(encoder)
encoder = BatchNormalization()(encoder) encoder = BatchNormalization()(encoder)
encoder = Dropout(self.DROPOUT_RATE)(encoder) encoder = Dropout(self.DROPOUT_RATE)(encoder)
encoder = Model_E4(encoder) encoder = Sequential(Model_E4)(encoder)
encoder = BatchNormalization()(encoder) encoder = BatchNormalization()(encoder)
# encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder) # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
...@@ -546,7 +555,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -546,7 +555,7 @@ class SEQ_2_SEQ_GMVAE:
tfd.Independent( tfd.Independent(
tfd.Normal( tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k], loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING :, k]), scale=softplus(gauss[1][..., self.ENCODING:, k]),
), ),
reinterpreted_batch_ndims=1, reinterpreted_batch_ndims=1,
) )
...@@ -588,7 +597,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -588,7 +597,7 @@ class SEQ_2_SEQ_GMVAE:
)(z) )(z)
# Define and instantiate generator # Define and instantiate generator
generator = Model_D1(z) generator = Sequential(Model_D1)(z)
generator = Model_B1(generator) generator = Model_B1(generator)
generator = Model_D2(generator) generator = Model_D2(generator)
generator = Model_B2(generator) generator = Model_B2(generator)
...@@ -650,7 +659,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -650,7 +659,7 @@ class SEQ_2_SEQ_GMVAE:
# Build generator as a separate entity # Build generator as a separate entity
g = Input(shape=self.ENCODING) g = Input(shape=self.ENCODING)
_generator = Model_D1(g) _generator = Sequential(Model_D1)(g)
_generator = Model_B1(_generator) _generator = Model_B1(_generator)
_generator = Model_D2(_generator) _generator = Model_D2(_generator)
_generator = Model_B2(_generator) _generator = Model_B2(_generator)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment