Commit 206a376d authored by lucas_miranda's avatar lucas_miranda
Browse files

Minimise entropy to see if overall confidence increases in a reproducible way

parent d96c8c4c
...@@ -169,9 +169,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss): ...@@ -169,9 +169,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
def get_config(self): def get_config(self):
config = super().get_config().copy() config = super().get_config().copy()
config.update( config.update(
{ {"is_placeholder": self.is_placeholder,}
"is_placeholder":self.is_placeholder,
}
) )
return config return config
...@@ -357,7 +355,7 @@ class Entropy_regulariser(Layer): ...@@ -357,7 +355,7 @@ class Entropy_regulariser(Layer):
# axis=1 increases the entropy of a cluster across instances # axis=1 increases the entropy of a cluster across instances
# axis=0 increases the entropy of the assignment for a given instance # axis=0 increases the entropy of the assignment for a given instance
entropy = - K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1) entropy = -K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
# Adds metric that monitors dead neurons in the latent space # Adds metric that monitors dead neurons in the latent space
self.add_metric(entropy, aggregation="mean", name="-weight_entropy") self.add_metric(entropy, aggregation="mean", name="-weight_entropy")
......
...@@ -171,7 +171,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -171,7 +171,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components=1, number_of_components=1,
predictor=True, predictor=True,
overlap_loss=False, overlap_loss=False,
entropy_reg_weight=0.25, entropy_reg_weight=1.0,
): ):
self.input_shape = input_shape self.input_shape = input_shape
self.batch_size = batch_size self.batch_size = batch_size
...@@ -315,7 +315,9 @@ class SEQ_2_SEQ_GMVAE: ...@@ -315,7 +315,9 @@ class SEQ_2_SEQ_GMVAE:
encoder = BatchNormalization()(encoder) encoder = BatchNormalization()(encoder)
encoding_shuffle = MCDropout(self.DROPOUT_RATE)(encoder) encoding_shuffle = MCDropout(self.DROPOUT_RATE)(encoder)
z_cat = Dense(self.number_of_components, activation="softmax",)(encoding_shuffle) z_cat = Dense(self.number_of_components, activation="softmax",)(
encoding_shuffle
)
z_cat = Entropy_regulariser(self.entropy_reg_weight)(z_cat) z_cat = Entropy_regulariser(self.entropy_reg_weight)(z_cat)
z_gauss = Dense( z_gauss = Dense(
tfpl.IndependentNormal.params_size( tfpl.IndependentNormal.params_size(
...@@ -468,7 +470,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -468,7 +470,7 @@ class SEQ_2_SEQ_GMVAE:
gmvaep.compile( gmvaep.compile(
loss=huber_loss, loss=huber_loss,
optimizer=Nadam(lr=self.learn_rate), optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
metrics=["mae"], metrics=["mae"],
loss_weights=([1, self.predictor] if self.predictor > 0 else [1]), loss_weights=([1, self.predictor] if self.predictor > 0 else [1]),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment