Commit 7b86681d authored by lucas_miranda's avatar lucas_miranda
Browse files

Instantiated a mean and a variance for each component, and the categorical...

Instantiated a mean and a variance for each component, and the categorical prior as a Dense layer with a softmax activation
parent dba020ba
...@@ -344,9 +344,7 @@ class SEQ_2_SEQ_VAE: ...@@ -344,9 +344,7 @@ class SEQ_2_SEQ_VAE:
return self.input_shape[1:] * huber(x_, x_decoded_mean_) return self.input_shape[1:] * huber(x_, x_decoded_mean_)
vae.compile( vae.compile(
loss=huber_loss, loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
optimizer=Adam(lr=self.learn_rate,),
metrics=["mae"],
) )
return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback
...@@ -587,9 +585,7 @@ class SEQ_2_SEQ_VAEP: ...@@ -587,9 +585,7 @@ class SEQ_2_SEQ_VAEP:
return self.input_shape[1:] * huber(x_, x_decoded_mean_) return self.input_shape[1:] * huber(x_, x_decoded_mean_)
vaep.compile( vaep.compile(
loss=huber_loss, loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
optimizer=Adam(lr=self.learn_rate,),
metrics=["mae"],
) )
return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback
...@@ -597,19 +593,19 @@ class SEQ_2_SEQ_VAEP: ...@@ -597,19 +593,19 @@ class SEQ_2_SEQ_VAEP:
class SEQ_2_SEQ_MMVAEP: class SEQ_2_SEQ_MMVAEP:
def __init__( def __init__(
self, self,
input_shape, input_shape,
CONV_filters=256, CONV_filters=256,
LSTM_units_1=256, LSTM_units_1=256,
LSTM_units_2=64, LSTM_units_2=64,
DENSE_2=64, DENSE_2=64,
DROPOUT_RATE=0.25, DROPOUT_RATE=0.25,
ENCODING=32, ENCODING=32,
learn_rate=1e-3, learn_rate=1e-3,
loss="ELBO+MMD", loss="ELBO+MMD",
kl_warmup_epochs=0, kl_warmup_epochs=0,
mmd_warmup_epochs=0, mmd_warmup_epochs=0,
number_of_components=1, number_of_components=1,
): ):
self.input_shape = input_shape self.input_shape = input_shape
self.CONV_filters = CONV_filters self.CONV_filters = CONV_filters
...@@ -626,7 +622,11 @@ class SEQ_2_SEQ_MMVAEP: ...@@ -626,7 +622,11 @@ class SEQ_2_SEQ_MMVAEP:
self.number_of_components = number_of_components self.number_of_components = number_of_components
assert ( assert (
"ELBO" in self.loss or "MMD" in self.loss self.number_of_components > 0
), "The number of components must be an integer greater than zero"
assert (
"ELBO" in self.loss or "MMD" in self.loss
), "loss must be one of ELBO, MMD or ELBO+MMD (default)" ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
def build(self): def build(self):
...@@ -684,8 +684,8 @@ class SEQ_2_SEQ_MMVAEP: ...@@ -684,8 +684,8 @@ class SEQ_2_SEQ_MMVAEP:
Model_D0 = DenseTranspose( Model_D0 = DenseTranspose(
Model_E5, activation="relu", output_dim=self.ENCODING, Model_E5, activation="relu", output_dim=self.ENCODING,
) )
Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2, ) Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1, ) Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
Model_D3 = RepeatVector(self.input_shape[1]) Model_D3 = RepeatVector(self.input_shape[1])
Model_D4 = Bidirectional( Model_D4 = Bidirectional(
LSTM( LSTM(
...@@ -719,8 +719,12 @@ class SEQ_2_SEQ_MMVAEP: ...@@ -719,8 +719,12 @@ class SEQ_2_SEQ_MMVAEP:
encoder = BatchNormalization()(encoder) encoder = BatchNormalization()(encoder)
encoder = Model_E5(encoder) encoder = Model_E5(encoder)
z_mean = Dense(self.ENCODING)(encoder) # Categorical prior on mixture of Gaussians
z_log_sigma = Dense(self.ENCODING)(encoder) categories = Dense(self.number_of_components, activation="softmax")
for _ in range(self.number_of_components):
z_mean = Dense(self.ENCODING)(encoder)
z_log_sigma = Dense(self.ENCODING)(encoder)
# Define and control custom loss functions # Define and control custom loss functions
kl_warmup_callback = False kl_warmup_callback = False
...@@ -832,9 +836,7 @@ class SEQ_2_SEQ_MMVAEP: ...@@ -832,9 +836,7 @@ class SEQ_2_SEQ_MMVAEP:
return self.input_shape[1:] * huber(x_, x_decoded_mean_) return self.input_shape[1:] * huber(x_, x_decoded_mean_)
gmvaep.compile( gmvaep.compile(
loss=huber_loss, loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
optimizer=Adam(lr=self.learn_rate, ),
metrics=["mae"],
) )
return encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback return encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment