Commit 09bd026d authored by lucas_miranda's avatar lucas_miranda
Browse files

Enhanced performance with tf.function decorators

parent 5d154b87
...@@ -166,6 +166,15 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss): ...@@ -166,6 +166,15 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
self.is_placeholder = True self.is_placeholder = True
super(KLDivergenceLayer, self).__init__(*args, **kwargs) super(KLDivergenceLayer, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update(
{
"is_placeholder":self.is_placeholder,
}
)
return config
def call(self, distribution_a): def call(self, distribution_a):
kl_batch = self._regularizer(distribution_a) kl_batch = self._regularizer(distribution_a)
self.add_loss(kl_batch, inputs=[distribution_a]) self.add_loss(kl_batch, inputs=[distribution_a])
......
...@@ -171,7 +171,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -171,7 +171,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components=1, number_of_components=1,
predictor=True, predictor=True,
overlap_loss=False, overlap_loss=False,
entropy_reg_weight=1.0, entropy_reg_weight=0.0,
): ):
self.input_shape = input_shape self.input_shape = input_shape
self.batch_size = batch_size self.batch_size = batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment