Commit 7b86681d authored by lucas_miranda

Instantiated a mean and a variance for each component, and the categorical prior as a Dense layer with a softmax activation
parent dba020ba
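For context, the commit message describes a Gaussian-mixture latent space: one mean and one log-variance head per mixture component, plus a categorical prior over components expressed as a softmax Dense layer. Below is a minimal, self-contained sketch of that idea (layer sizes and variable names are hypothetical, not the project's actual code); note that it keeps the per-component heads in lists so they stay addressable:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

ENCODING = 32     # latent dimensionality (matches the class default below)
N_COMPONENTS = 3  # number of mixture components (assumed for illustration)

inputs = Input(shape=(64,))                     # stand-in for the encoder output
encoder = Dense(64, activation="relu")(inputs)

# Categorical prior over the mixture components, as a softmax Dense layer
categories = Dense(N_COMPONENTS, activation="softmax")(encoder)

# One mean and one log-variance head per component
z_means = [Dense(ENCODING)(encoder) for _ in range(N_COMPONENTS)]
z_log_sigmas = [Dense(ENCODING)(encoder) for _ in range(N_COMPONENTS)]

model = Model(inputs, [categories] + z_means + z_log_sigmas)
model.summary()
```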
@@ -344,9 +344,7 @@ class SEQ_2_SEQ_VAE:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         vae.compile(
-            loss=huber_loss,
-            optimizer=Adam(lr=self.learn_rate,),
-            metrics=["mae"],
+            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
         )
 
         return encoder, generator, vae, kl_warmup_callback, mmd_warmup_callback
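The compile call itself only changes formatting, but the scaled Huber loss it uses is worth spelling out: the per-sample reconstruction loss is weighted by the input dimensions. A runnable sketch of the same pattern on a toy model (the weight 10 stands in for self.input_shape[1:], the Huber delta is an assumption, and learning_rate is the non-deprecated spelling of the lr argument used in the diff):

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.losses import Huber
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

huber = Huber()  # default delta=1.0; the project's setting may differ

def huber_loss(y_true, y_pred):
    # Mirrors the diff: scale the Huber loss by the input dimensions
    # (self.input_shape[1:] in the original); 10 is a stand-in here
    return 10 * huber(y_true, y_pred)

inputs = Input(shape=(10,))
outputs = Dense(10)(inputs)
model = Model(inputs, outputs)
model.compile(loss=huber_loss, optimizer=Adam(learning_rate=1e-3), metrics=["mae"])
```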
@@ -587,9 +585,7 @@ class SEQ_2_SEQ_VAEP:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         vaep.compile(
-            loss=huber_loss,
-            optimizer=Adam(lr=self.learn_rate,),
-            metrics=["mae"],
+            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
        )
 
         return encoder, generator, vaep, kl_warmup_callback, mmd_warmup_callback
@@ -597,19 +593,19 @@ class SEQ_2_SEQ_VAEP:
 
 class SEQ_2_SEQ_MMVAEP:
     def __init__(
-        self,
-        input_shape,
-        CONV_filters=256,
-        LSTM_units_1=256,
-        LSTM_units_2=64,
-        DENSE_2=64,
-        DROPOUT_RATE=0.25,
-        ENCODING=32,
-        learn_rate=1e-3,
-        loss="ELBO+MMD",
-        kl_warmup_epochs=0,
-        mmd_warmup_epochs=0,
-        number_of_components=1,
+        self,
+        input_shape,
+        CONV_filters=256,
+        LSTM_units_1=256,
+        LSTM_units_2=64,
+        DENSE_2=64,
+        DROPOUT_RATE=0.25,
+        ENCODING=32,
+        learn_rate=1e-3,
+        loss="ELBO+MMD",
+        kl_warmup_epochs=0,
+        mmd_warmup_epochs=0,
+        number_of_components=1,
     ):
         self.input_shape = input_shape
         self.CONV_filters = CONV_filters
@@ -626,7 +622,11 @@ class SEQ_2_SEQ_MMVAEP:
         self.number_of_components = number_of_components
 
         assert (
-            "ELBO" in self.loss or "MMD" in self.loss
+            self.number_of_components > 0
+        ), "The number of components must be an integer greater than zero"
+
+        assert (
+            "ELBO" in self.loss or "MMD" in self.loss
         ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"
 
     def build(self):
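The new guard validates the component count before the existing loss check. A small self-contained illustration of the same assertion pattern (the function name is hypothetical; note the message promises an integer check, but only positivity is asserted):

```python
def validate(loss="ELBO+MMD", number_of_components=1):
    assert (
        number_of_components > 0
    ), "The number of components must be an integer greater than zero"
    assert (
        "ELBO" in loss or "MMD" in loss
    ), "loss must be one of ELBO, MMD or ELBO+MMD (default)"

validate(number_of_components=1)  # passes silently
validate(number_of_components=0)  # raises AssertionError
```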
@@ -684,8 +684,8 @@ class SEQ_2_SEQ_MMVAEP:
         Model_D0 = DenseTranspose(
             Model_E5, activation="relu", output_dim=self.ENCODING,
         )
-        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2, )
-        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1, )
+        Model_D1 = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2,)
+        Model_D2 = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1,)
         Model_D3 = RepeatVector(self.input_shape[1])
         Model_D4 = Bidirectional(
             LSTM(
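DenseTranspose is the project's own layer and its implementation is not shown in this diff; judging by its signature, it builds a decoder layer tied to an encoder layer's weights. A sketch of one common way to implement such a tied-weights layer (an assumption, not necessarily deepof's actual code):

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer

class TiedDenseTranspose(Layer):
    """Dense layer that reuses the (transposed) kernel of a reference
    Dense layer -- a standard way to tie encoder/decoder weights."""

    def __init__(self, dense, activation=None, output_dim=None, **kwargs):
        self.dense = dense  # reference layer whose kernel is reused
        self.activation = tf.keras.activations.get(activation)
        self.output_dim = output_dim  # must equal the reference layer's input dim
        super().__init__(**kwargs)

    def build(self, batch_input_shape):
        # Only the bias is a new variable; the kernel is borrowed
        self.bias = self.add_weight(
            name="bias", shape=[self.output_dim], initializer="zeros"
        )
        super().build(batch_input_shape)

    def call(self, inputs):
        z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)
        return self.activation(z + self.bias)

# Usage: tie a decoder layer to an already-built encoder layer
enc = Dense(16)
h = enc(tf.random.normal([2, 32]))  # builds enc with kernel of shape (32, 16)
dec = TiedDenseTranspose(enc, activation="relu", output_dim=32)
y = dec(h)                          # shape (2, 32)
```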
@@ -719,8 +719,12 @@ class SEQ_2_SEQ_MMVAEP:
         encoder = BatchNormalization()(encoder)
         encoder = Model_E5(encoder)
 
-        z_mean = Dense(self.ENCODING)(encoder)
-        z_log_sigma = Dense(self.ENCODING)(encoder)
+        # Categorical prior on mixture of Gaussians
+        categories = Dense(self.number_of_components, activation="softmax")
+
+        for _ in range(self.number_of_components):
+            z_mean = Dense(self.ENCODING)(encoder)
+            z_log_sigma = Dense(self.ENCODING)(encoder)
 
         # Define and control custom loss functions
         kl_warmup_callback = False
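As committed, the loop rebinds z_mean and z_log_sigma on every pass, so only the last component's heads stay referenced, and categories is a layer object that is not yet applied to the encoder; the wiring of these pieces is left to later work. One hypothetical way the per-component outputs could be combined downstream, once kept in lists (an illustration, not part of this commit; z_log_sigma is treated as a log-variance, a common convention despite the name):

```python
import tensorflow as tf

def sample_mixture(categories, z_means, z_log_sigmas):
    """Hypothetical downstream step: reparameterize within each component
    and weight the samples by the categorical prior (soft mixing)."""
    means = tf.stack(z_means, axis=1)            # (batch, k, dim)
    log_sigmas = tf.stack(z_log_sigmas, axis=1)  # (batch, k, dim)
    eps = tf.random.normal(tf.shape(means))
    samples = means + tf.exp(0.5 * log_sigmas) * eps  # per-component samples
    # Weight each component's sample by its prior probability and sum
    return tf.reduce_sum(categories[..., None] * samples, axis=1)

# Toy shapes: batch of 8, k=3 components, 32-dimensional latent space
cats = tf.nn.softmax(tf.random.normal([8, 3]))
mus = [tf.random.normal([8, 32]) for _ in range(3)]
logs = [tf.random.normal([8, 32]) for _ in range(3)]
z = sample_mixture(cats, mus, logs)  # shape (8, 32)
```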
@@ -832,9 +836,7 @@ class SEQ_2_SEQ_MMVAEP:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         gmvaep.compile(
-            loss=huber_loss,
-            optimizer=Adam(lr=self.learn_rate, ),
-            metrics=["mae"],
+            loss=huber_loss, optimizer=Adam(lr=self.learn_rate,), metrics=["mae"],
         )
 
         return encoder, generator, gmvaep, kl_warmup_callback, mmd_warmup_callback