Commit 8d314040 authored by lucas_miranda's avatar lucas_miranda
Browse files

Started implementing annealing mode for KL divergence

parent bba9f951
Pipeline #100113 failed with stages
in 17 minutes and 10 seconds
......@@ -442,7 +442,9 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
if self._annealing_mode == "sigmoid":
kl_weight = 1.0 / (1.0 + tf.exp(-kl_weight))
kl_weight = tf.math.sigmoid(
(2 * kl_weight - 1) / (kl_weight - kl_weight ** 2)
)
else:
raise NotImplementedError(
"annealing_mode must be one of 'linear' and 'sigmoid'"
......@@ -471,14 +473,7 @@ class MMDiscrepancyLayer(Layer):
"""
def __init__(
self,
batch_size,
prior,
iters,
warm_up_iters,
annealing_mode,
*args,
**kwargs
self, batch_size, prior, iters, warm_up_iters, annealing_mode, *args, **kwargs
):
super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
self.is_placeholder = True
......@@ -511,7 +506,9 @@ class MMDiscrepancyLayer(Layer):
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
if self._annealing_mode == "sigmoid":
mmd_weight = 1.0 / (1.0 + tf.exp(-mmd_weight))
mmd_weight = tf.math.sigmoid(
(2 * mmd_weight - 1) / (mmd_weight - mmd_weight ** 2)
)
else:
raise NotImplementedError(
"annealing_mode must be one of 'linear' and 'sigmoid'"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment