Commit 007d9e44 authored by lucas_miranda's avatar lucas_miranda
Browse files

Updated default GMVAE models

parent 2b84273d
......@@ -231,7 +231,6 @@ class SEQ_2_SEQ_GMVAE:
architecture_hparams: dict = {},
batch_size: int = 256,
compile_model: bool = True,
dense_activation: str = "elu",
entropy_reg_weight: float = 0.0,
huber_delta: float = 1.0,
initialiser_iters: int = int(1e4),
......@@ -246,15 +245,17 @@ class SEQ_2_SEQ_GMVAE:
self.hparams = self.get_hparams(architecture_hparams)
self.batch_size = batch_size
self.CONV_filters = self.hparams["units_conv"]
self.dense_activation = dense_activation
self.LSTM_units_1 = self.hparams["units_lstm"]
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.lstm_unroll = True
self.DENSE_1 = int(self.hparams["units_lstm"] / 2)
self.DENSE_2 = self.hparams["units_dense2"]
self.DROPOUT_RATE = self.hparams["dropout_rate"]
self.ENCODING = self.hparams["encoding"]
self.LSTM_units_1 = self.hparams["units_lstm"]
self.LSTM_units_2 = int(self.hparams["units_lstm"] / 2)
self.clipvalue = self.hparams["clipvalue"]
self.dense_activation = self.hparams["dense_activation"]
self.learn_rate = self.hparams["learning_rate"]
self.optimizer = self.hparams["optimizer"]
self.lstm_unroll = True
self.compile = compile_model
self.delta = huber_delta
self.entropy_reg_weight = entropy_reg_weight
......@@ -316,9 +317,11 @@ class SEQ_2_SEQ_GMVAE:
"units_conv": 256,
"units_lstm": 256,
"units_dense2": 64,
"dropout_rate": 0.25,
"dropout_rate": 0.15,
"encoding": 16,
"learning_rate": 1e-3,
"clipvalue": 0.5,
"dense_activation": "elu",
}
for k, v in params.items():
......@@ -663,7 +666,7 @@ class SEQ_2_SEQ_GMVAE:
if self.compile:
gmvaep.compile(
loss=model_losses,
optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
optimizer=Nadam(lr=self.learn_rate, clipvalue=self.clipvalue,),
metrics=model_metrics,
loss_weights=loss_weights,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment