Commit 998f5107 authored by lucas_miranda

Reimplemented KL warmup using the optimizer's iteration counter; getting rid of the clumsy callback

parent 90ba416a
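
The idea behind the change: instead of updating the KL weight from a Keras callback once per epoch, the layer reads the optimizer's iteration counter directly and ramps the weight from 0 to 1 over the warmup period. A minimal sketch of that annealing rule (illustrative names, not the deepof code itself):

    import tensorflow as tf

    def kl_weight(iters, warm_up_iters):
        """Linear KL annealing: ramps 0 -> 1 over the first `warm_up_iters` steps."""
        return tf.minimum(
            tf.cast(iters, tf.float32) / tf.cast(warm_up_iters, tf.float32), 1.0
        )

    # `iters` would typically be the optimizer's own step counter, e.g.:
    # optimizer = tf.keras.optimizers.Adam()
    # weight = kl_weight(optimizer.iterations, 1000)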
@@ -439,7 +439,6 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
         self.is_placeholder = True
         self._iters = iters
         self._warm_up_iters = warm_up_iters
-        self._regularizer._weight = K.min([self._iters / self._warm_up_iters, 1.0])

     def get_config(self):  # pragma: no cover
         """Updates Constraint metadata"""
@@ -451,6 +450,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):

     def call(self, distribution_a):
         """Updates Layer's call method"""
+        self._regularizer._weight = K.min([self._iters / self._warm_up_iters, 1.0])
         kl_batch = self._regularizer(distribution_a)
         self.add_loss(kl_batch, inputs=[distribution_a])
         self.add_metric(
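
Because the weight is now recomputed inside call(), every forward pass picks up the current iteration count, so the annealing advances with training automatically. A simplified, self-contained version of such a layer (assumed names, and a plain Layer subclass rather than deepof's tfpl.KLDivergenceAddLoss subclass):

    import tensorflow as tf
    import tensorflow_probability as tfp

    class AnnealedKLLayer(tf.keras.layers.Layer):
        """Adds a KL(q || prior) loss whose weight warms up over training steps."""

        def __init__(self, prior, iters, warm_up_iters, **kwargs):
            super().__init__(**kwargs)
            self.prior = prior
            self._iters = iters                  # e.g. optimizer.iterations
            self._warm_up_iters = warm_up_iters

        def call(self, distribution_a):
            # Recompute the warmup weight on every call, so it tracks training.
            weight = tf.minimum(
                tf.cast(self._iters, tf.float32)
                / tf.cast(self._warm_up_iters, tf.float32),
                1.0,
            )
            kl_batch = tfp.distributions.kl_divergence(distribution_a, self.prior)
            self.add_loss(weight * tf.reduce_mean(kl_batch))
            self.add_metric(tf.reduce_mean(kl_batch), name="kl_divergence")
            return distribution_a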
@@ -626,7 +626,8 @@ class SEQ_2_SEQ_GMVAE:
         if "ELBO" in self.loss:
             warm_up_iters = tf.cast(
-                self.kl_warmup * (input_shape[0] / self.batch_size), tf.int64
+                self.kl_warmup * (input_shape[0] / self.batch_size),
+                tf.int64,
             )
             # noinspection PyCallingNonCallable
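
The hunk above also clarifies what warm_up_iters means: kl_warmup is given in epochs, and input_shape[0] / batch_size is the number of optimizer steps per epoch, so the product is the warmup length in steps. A worked example with assumed values:

    # Assumed values for illustration only.
    kl_warmup, n_instances, batch_size = 10, 10_000, 256
    warm_up_iters = int(kl_warmup * (n_instances / batch_size))  # 10 * 39.0625 -> 390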
@@ -332,7 +332,6 @@ def autoencoder_fitting(
         generator,
         grouper,
         ae,
-        kl_warmup_callback,
         mmd_warmup_callback,
     ) = deepof.models.SEQ_2_SEQ_GMVAE(
         architecture_hparams=({} if hparams is None else hparams),
@@ -395,9 +394,6 @@ def autoencoder_fitting(
         ),
     ]

-    if "ELBO" in loss and kl_warmup > 0:
-        # noinspection PyUnboundLocalVariable
-        callbacks_.append(kl_warmup_callback)
     if "MMD" in loss and mmd_warmup > 0:
         # noinspection PyUnboundLocalVariable
         callbacks_.append(mmd_warmup_callback)
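
For contrast, the callback-based pattern being removed here usually looks something like the sketch below (assumed names; the actual removed callback may have differed). The backend variable has to be created at model-build time and the callback threaded through to model.fit(), which is what forced the PyUnboundLocalVariable suppressions above:

    import tensorflow.keras.backend as K
    from tensorflow.keras.callbacks import LambdaCallback

    kl_beta = K.variable(0.0, name="kl_beta")  # weight read by the KL loss term
    kl_warmup = 10                             # warmup length in epochs

    kl_warmup_callback = LambdaCallback(
        on_epoch_begin=lambda epoch, logs: K.set_value(
            kl_beta, min(epoch / kl_warmup, 1.0)
        )
    )
    # model.fit(..., callbacks=[kl_warmup_callback])

Driving the weight from the optimizer's iteration counter removes both the extra object to pass around and the per-epoch granularity: the weight now ramps smoothly per batch.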