Commit 7b86681d authored by lucas_miranda's avatar lucas_miranda
Browse files

Instanciated a mean and a variance for each component, and the categorical...

Instanciated a mean and a variance for each component, and the categorical prior as a Dense layer with a softmax activation
parent dba020ba
......@@ -344,9 +344,7 @@ class SEQ_2_SEQ_VAE:
return self.input_shape[1:] * huber(x_, x_decoded_mean_)
vae.compile(
loss=huber_loss,
optimizer=Adam(lr=self.learn_rate,),
metrics=["mae"],
loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
)
return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback
......@@ -587,9 +585,7 @@ class SEQ_2_SEQ_VAEP:
return self.input_shape[1:] * huber(x_, x_decoded_mean_)
vaep.compile(
loss=huber_loss,
optimizer=Adam(lr=self.learn_rate,),
metrics=["mae"],
loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
)
return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback
......@@ -625,6 +621,10 @@ class SEQ_2_SEQ_MMVAEP:
self.mmd_warmup = mmd_warmup_epochs
self.number_of_components = number_of_components
assert (
self.number_of_components > 0
), "The number of components must be an integer greater than zero"
assert (
"ELBO" in self.loss or "MMD" in self.loss
), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
......@@ -684,8 +684,8 @@ class SEQ_2_SEQ_MMVAEP:
Model_D0 = DenseTranspose(
Model_E5, activation="relu", output_dim=self.ENCODING,
)
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2, )
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1, )
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
Model_D3 = RepeatVector(self.input_shape[1])
Model_D4 = Bidirectional(
LSTM(
......@@ -719,6 +719,10 @@ class SEQ_2_SEQ_MMVAEP:
encoder = BatchNormalization()(encoder)
encoder = Model_E5(encoder)
# Categorical prior on mixture of Gaussians
categories = Dense(self.number_of_components, activation="softmax")
for _ in range(self.number_of_components):
z_mean = Dense(self.ENCODING)(encoder)
z_log_sigma = Dense(self.ENCODING)(encoder)
......@@ -832,9 +836,7 @@ class SEQ_2_SEQ_MMVAEP:
return self.input_shape[1:] * huber(x_, x_decoded_mean_)
gmvaep.compile(
loss=huber_loss,
optimizer=Adam(lr=self.learn_rate, ),
metrics=["mae"],
loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
)
return encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment