Commit ce80ff28 authored by lucas_miranda

Instantiated a mean and a variance for each component, and the categorical prior as a Dense layer with a softmax activation
parent 226b9801
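
For context, the pattern this commit introduces can be sketched as a minimal, self-contained tf.keras snippet: one mean and one log-sigma head per mixture component, plus the categorical prior expressed as a softmax Dense layer. The toy sizes (`encoding_dim`, `n_components`) and the stub encoder are illustrative assumptions, not the repository's actual SEQ_2_SEQ_MMVAEP model:

```python
# Illustrative sketch only -- toy sizes and the stub encoder are assumptions,
# not the repository's actual SEQ_2_SEQ_MMVAEP code.
from tensorflow.keras.layers import Dense, Input

encoding_dim = 6   # stands in for self.ENCODING
n_components = 3   # stands in for self.number_of_components

inputs = Input(shape=(24,))
encoder = Dense(16, activation="relu")(inputs)

# Categorical prior over the mixture components: a single softmax Dense layer
categories = Dense(n_components, activation="softmax")(encoder)

# One mean and one log-sigma vector per prior component, collected in lists
z_mean = [
    Dense(encoding_dim, name="{}_gaussian_mean".format(i + 1))(encoder)
    for i in range(n_components)
]
z_log_sigma = [
    Dense(encoding_dim, name="{}_gaussian_sigma".format(i + 1))(encoder)
    for i in range(n_components)
]
```

Note that in the diff itself `categories` is only instantiated, not yet called on the encoder output, so it presumably gets wired into the graph in a later commit.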
@@ -125,14 +125,14 @@ class SEQ_2_SEQ_AE:
         # Define and instantiate decoder
         decoder = Sequential(name="SEQ_2_SEQ_Decoder")
         decoder.add(Model_D0)
-        encoder.add(BatchNormalization())
+        decoder.add(BatchNormalization())
         decoder.add(Model_D1)
-        encoder.add(BatchNormalization())
+        decoder.add(BatchNormalization())
         decoder.add(Model_D2)
-        encoder.add(BatchNormalization())
+        decoder.add(BatchNormalization())
         decoder.add(Model_D3)
         decoder.add(Model_D4)
-        encoder.add(BatchNormalization())
+        decoder.add(BatchNormalization())
         decoder.add(Model_D5)
         decoder.add(TimeDistributed(Dense(self.input_shape[2])))
@@ -722,10 +722,15 @@ class SEQ_2_SEQ_MMVAEP:
+        # Categorical prior on mixture of Gaussians
+        categories = Dense(self.number_of_components, activation="softmax")
+        # Define mean and log_sigma as lists of vectors with an item per prior component
+        z_mean = []
+        z_log_sigma = []
         for i in range(self.number_of_components):
-            z_mean = Dense(self.ENCODING, name="{}_gaussian_mean".format(i+1))(encoder)
-            z_log_sigma = Dense(self.ENCODING, name="{}_gaussian_sigma".format(i+1))(
-                encoder
-            )
+            z_mean.append(
+                Dense(self.ENCODING, name="{}_gaussian_mean".format(i + 1))(encoder)
+            )
+            z_log_sigma.append(
+                Dense(self.ENCODING, name="{}_gaussian_sigma".format(i + 1))(encoder)
+            )

         # Define and control custom loss functions
@@ -741,7 +746,9 @@ class SEQ_2_SEQ_MMVAEP:
             )
         )
-        z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])
+        z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)(
+            [z_mean[0], z_log_sigma[0]]
+        )
         z = Lambda(sampling)([z_mean, z_log_sigma])
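
Neither `sampling` nor `KLDivergenceLayer` is defined in this diff. A common formulation of both helpers in Keras VAE code is sketched below; this is an assumption about their behaviour, not the repository's implementation. Note also that the new code applies the KL term to the first mixture component only (`z_mean[0]`, `z_log_sigma[0]`), apparently as a placeholder until a proper mixture KL is in place:

```python
# Hedged sketch of the two helpers used above; their real definitions live
# elsewhere in the repository and may differ from this.
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer


def sampling(args):
    """Reparameterisation trick: z = mean + exp(log_sigma) * eps, eps ~ N(0, I)."""
    z_mean, z_log_sigma = args
    eps = K.random_normal(shape=K.shape(z_mean))
    return z_mean + K.exp(z_log_sigma) * eps


class KLDivergenceLayer(Layer):
    """Identity layer that adds beta * KL(N(mean, sigma^2) || N(0, I)) to the loss."""

    def __init__(self, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta

    def call(self, inputs):
        z_mean, z_log_sigma = inputs
        # Closed-form KL against a standard normal prior, summed over latent dims
        kl = -0.5 * K.sum(
            1 + 2 * z_log_sigma - K.square(z_mean) - K.exp(2 * z_log_sigma), axis=-1
        )
        self.add_loss(self.beta * K.mean(kl))
        return inputs
```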
@@ -845,9 +852,11 @@ class SEQ_2_SEQ_MMVAEP:
 # TODO:
 # - Models as proper tf.keras.Model subclasses
 # - Gaussian Mixture + Categorical priors -> Deep Clustering
 #
 # TODO (in the non-immediate future):
 # - MCMC sampling (n>1)
 # - free bits paper
 # - Attention mechanism for encoder / decoder (does it make sense?)
 # - Transformer encoder/decoder (does it make sense?)