Commit 2bdf2679 authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated default GMVAE models

parent 007d9e44
Pipeline #87824 failed with stage
in 15 minutes and 37 seconds
......@@ -106,15 +106,20 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
"""Retrieve hyperparameters to tune"""
# Architectural hyperparameters
clipvalue = hp.Float(
"clipvalue",
min_value=0.0,
max_value=1.0,
default=0.5,
sampling="Linear"
)
conv_filters = hp.Int(
"units_conv", min_value=128, max_value=160, step=16, default=128,
)
lstm_units_1 = hp.Int(
"units_lstm", min_value=300, max_value=350, step=10, default=320,
)
dense_2 = hp.Int(
"units_dense2", min_value=120, max_value=180, step=10, default=150,
)
dense_activation = hp.Choice(["elu", "relu"])
dropout_rate = hp.Float(
"dropout_rate",
min_value=0.0,
......@@ -122,25 +127,21 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
default=0.0,
sampling="linear",
)
encoding = (
16 # hp.Int("encoding", min_value=20, max_value=30, step=5, default=25, )
encoding = 16
k = self.number_of_components
lstm_units_1 = hp.Int(
"units_lstm", min_value=300, max_value=350, step=10, default=320,
)
k = 5 # hp.Int(
# "n_components",
# min_value=self.number_of_components - 5,
# max_value=self.number_of_components + 5,
# step=5,
# default=self.number_of_components,
# # sampling="linear",
# )
return (
clipvalue,
conv_filters,
lstm_units_1,
dense_2,
dense_activation,
dropout_rate,
encoding,
k,
lstm_units_1,
)
def build(self, hp):
......@@ -148,21 +149,25 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
# Hyperparameters to tune
(
clipvalue,
conv_filters,
lstm_units_1,
dense_2,
dense_activation,
dropout_rate,
encoding,
k,
lstm_units_1,
) = self.get_hparams(hp)
gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
architecture_hparams={
"units_conv": conv_filters,
"units_lstm": lstm_units_1,
"units_dense_2": dense_2,
"clipvalue": clipvalue,
"dense_activation": dense_activation,
"dropout_rate": dropout_rate,
"encoding": encoding,
"units_conv": conv_filters,
"units_dense_2": dense_2,
"units_lstm": lstm_units_1,
},
entropy_reg_weight=self.entropy_reg_weight,
huber_delta=self.huber_delta,
......
......@@ -233,7 +233,7 @@ class SEQ_2_SEQ_GMVAE:
compile_model: bool = True,
entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0,
initialiser_iters: int = int(1e4),
initialiser_iters: int = int(1),
kl_warmup_epochs: int = 0,
loss: str = "ELBO+MMD",
mmd_warmup_epochs: int = 0,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment