Commit 3c5ef64e authored by lucas_miranda

Started implementing annealing mode for KL divergence

parent 63c23405
@@ -437,14 +437,16 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
 
         # Define and update KL weight for warmup
         if self._warm_up_iters > 0:
-            if self._annealing_mode == "linear":
+            if self._annealing_mode in ["linear", "sigmoid"]:
                 kl_weight = tf.cast(
                     K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
                 )
-            elif self._annealing_mode == "sigmoid":
-                kl_weight = 0
+                if self._annealing_mode == "sigmoid":
+                    kl_weight = 1.0 / (1.0 + tf.exp(-kl_weight))
             else:
-                raise NotImplementedError("annealing_mode must be one of 'linear' and 'sigmoid'")
+                raise NotImplementedError(
+                    "annealing_mode must be one of 'linear' and 'sigmoid'"
+                )
         else:
             kl_weight = tf.cast(1.0, tf.float32)
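For reference, here is a minimal, self-contained sketch of the warm-up weight computed in this hunk. The helper name warm_up_weight and its scalar arguments are illustrative only; the commit computes the weight inline from the layer's _iters, _warm_up_iters, and _annealing_mode attributes:

    import tensorflow as tf

    def warm_up_weight(iters, warm_up_iters, annealing_mode="sigmoid"):
        # No warm-up requested: the KL term is weighted fully from the start.
        if warm_up_iters <= 0:
            return tf.cast(1.0, tf.float32)
        if annealing_mode in ["linear", "sigmoid"]:
            # Linear ramp from 0 to 1 over the first warm_up_iters iterations.
            weight = tf.cast(min(iters / warm_up_iters, 1.0), tf.float32)
            if annealing_mode == "sigmoid":
                # Logistic squashing of the linear ramp, as in this commit.
                weight = 1.0 / (1.0 + tf.exp(-weight))
            return weight
        raise NotImplementedError("annealing_mode must be one of 'linear' and 'sigmoid'")

Note that since the linear ramp lies in [0, 1], the logistic transform maps it to roughly [0.5, 0.73] rather than [0, 1]; rescaling or centering its input (for instance a logistic over k * (ramp - 0.5) with larger k) is one way a follow-up could make the sigmoid schedule span the full range.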
@@ -468,13 +470,23 @@ class MMDiscrepancyLayer(Layer):
     to the final model loss.
     """
 
-    def __init__(self, batch_size, prior, iters, warm_up_iters, *args, **kwargs):
+    def __init__(
+        self,
+        batch_size,
+        prior,
+        iters,
+        warm_up_iters,
+        annealing_mode="sigmoid",
+        *args,
+        **kwargs
+    ):
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
         self.is_placeholder = True
         self.batch_size = batch_size
         self.prior = prior
         self._iters = iters
         self._warm_up_iters = warm_up_iters
+        self._annealing_mode = annealing_mode
 
     def get_config(self):  # pragma: no cover
         """Updates Constraint metadata"""
@@ -484,6 +496,7 @@ class MMDiscrepancyLayer(Layer):
         config.update({"_iters": self._iters})
         config.update({"_warmup_iters": self._warm_up_iters})
         config.update({"prior": self.prior})
+        config.update({"_annealing_mode": self._annealing_mode})
 
         return config
     def call(self, z, **kwargs):
@@ -493,9 +506,16 @@ class MMDiscrepancyLayer(Layer):
 
         # Define and update MMD weight for warmup
         if self._warm_up_iters > 0:
-            mmd_weight = tf.cast(
-                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
-            )
+            if self._annealing_mode in ["linear", "sigmoid"]:
+                mmd_weight = tf.cast(
+                    K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
+                )
+                if self._annealing_mode == "sigmoid":
+                    mmd_weight = 1.0 / (1.0 + tf.exp(-mmd_weight))
+            else:
+                raise NotImplementedError(
+                    "annealing_mode must be one of 'linear' and 'sigmoid'"
+                )
         else:
             mmd_weight = tf.cast(1.0, tf.float32)
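As a usage illustration, the extended constructor might be called as below once this commit is applied. MMDiscrepancyLayer is assumed to be in scope (imported from deepof's model utilities), and the prior, latent dimensionality, and step counter are placeholder values, not taken from the repository:

    import tensorflow as tf
    import tensorflow_probability as tfp

    latent_dim = 16  # placeholder latent dimensionality
    prior = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(latent_dim))
    step = tf.Variable(0.0, trainable=False)  # training-step counter

    mmd_layer = MMDiscrepancyLayer(
        batch_size=64,
        prior=prior,
        iters=step,
        warm_up_iters=1000,
        annealing_mode="sigmoid",  # new argument in this commit; defaults to "sigmoid"
    )

With the matching get_config change above, the chosen annealing mode is also round-tripped when the layer is serialized.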
...