Commit 3c5ef64e authored by lucas_miranda's avatar lucas_miranda
Browse files

Started implementing annealing mode for KL divergence

parent 63c23405
......@@ -437,14 +437,16 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
# Define and update KL weight for warmup
if self._warm_up_iters > 0:
if self._annealing_mode == "linear":
if self._annealing_mode in ["linear", "sigmoid"]:
kl_weight = tf.cast(
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
elif self._annealing_mode == "sigmoid":
kl_weight = 0
if self._annealing_mode == "sigmoid":
kl_weight = 1.0 / (1.0 + tf.exp(-kl_weight))
else:
raise NotImplementedError("annealing_mode must be one of 'linear' and 'sigmoid'")
raise NotImplementedError(
"annealing_mode must be one of 'linear' and 'sigmoid'"
)
else:
kl_weight = tf.cast(1.0, tf.float32)
......@@ -468,13 +470,23 @@ class MMDiscrepancyLayer(Layer):
to the final model loss.
"""
def __init__(self, batch_size, prior, iters, warm_up_iters, *args, **kwargs):
def __init__(
self,
batch_size,
prior,
iters,
warm_up_iters,
annealing_mode="sigmoid",
*args,
**kwargs
):
super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
self.is_placeholder = True
self.batch_size = batch_size
self.prior = prior
self._iters = iters
self._warm_up_iters = warm_up_iters
self._annealing_mode = annealing_mode
def get_config(self): # pragma: no cover
"""Updates Constraint metadata"""
......@@ -484,6 +496,7 @@ class MMDiscrepancyLayer(Layer):
config.update({"_iters": self._iters})
config.update({"_warmup_iters": self._warm_up_iters})
config.update({"prior": self.prior})
config.update({"_annealing_mode": self._annealing_mode})
return config
def call(self, z, **kwargs):
......@@ -493,9 +506,16 @@ class MMDiscrepancyLayer(Layer):
# Define and update MMD weight for warmup
if self._warm_up_iters > 0:
mmd_weight = tf.cast(
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
if self._annealing_mode in ["linear", "sigmoid"]:
mmd_weight = tf.cast(
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
if self._annealing_mode == "sigmoid":
mmd_weight = 1.0 / (1.0 + tf.exp(-mmd_weight))
else:
raise NotImplementedError(
"annealing_mode must be one of 'linear' and 'sigmoid'"
)
else:
mmd_weight = tf.cast(1.0, tf.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment