Commit 09bd026d authored by lucas_miranda's avatar lucas_miranda
Browse files

Enhanced performance with tf.function decorators

parent 5d154b87
......@@ -166,6 +166,15 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
self.is_placeholder = True
super(KLDivergenceLayer, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update(
{
"is_placeholder":self.is_placeholder,
}
)
return config
def call(self, distribution_a):
kl_batch = self._regularizer(distribution_a)
self.add_loss(kl_batch, inputs=[distribution_a])
......
......@@ -171,7 +171,7 @@ class SEQ_2_SEQ_GMVAE:
number_of_components=1,
predictor=True,
overlap_loss=False,
entropy_reg_weight=1.0,
entropy_reg_weight=0.0,
):
self.input_shape = input_shape
self.batch_size = batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment