Commit 63c23405 authored by lucas_miranda
Browse files

Started implementing annealing mode for KL divergence

parent e7514b68
...@@ -415,11 +415,12 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss): ...@@ -415,11 +415,12 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
to the final model loss. to the final model loss.
""" """
def __init__(self, iters, warm_up_iters, *args, **kwargs): def __init__(self, iters, warm_up_iters, annealing_mode="sigmoid", *args, **kwargs):
super(KLDivergenceLayer, self).__init__(*args, **kwargs) super(KLDivergenceLayer, self).__init__(*args, **kwargs)
self.is_placeholder = True self.is_placeholder = True
self._iters = iters self._iters = iters
self._warm_up_iters = warm_up_iters self._warm_up_iters = warm_up_iters
self._annealing_mode = annealing_mode
def get_config(self): # pragma: no cover def get_config(self): # pragma: no cover
"""Updates Constraint metadata""" """Updates Constraint metadata"""
...@@ -428,6 +429,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss): ...@@ -428,6 +429,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
config.update({"is_placeholder": self.is_placeholder}) config.update({"is_placeholder": self.is_placeholder})
config.update({"_iters": self._iters}) config.update({"_iters": self._iters})
config.update({"_warm_up_iters": self._warm_up_iters}) config.update({"_warm_up_iters": self._warm_up_iters})
config.update({"_annealing_mode": self._annealing_mode})
return config return config
def call(self, distribution_a): def call(self, distribution_a):
...@@ -435,9 +437,14 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss): ...@@ -435,9 +437,14 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
# Define and update KL weight for warmup # Define and update KL weight for warmup
if self._warm_up_iters > 0: if self._warm_up_iters > 0:
kl_weight = tf.cast( if self._annealing_mode == "linear":
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32 kl_weight = tf.cast(
) K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
elif self._annealing_mode == "sigmoid":
kl_weight = 0
else:
raise NotImplementedError("annealing_mode must be one of 'linear' and 'sigmoid'")
else: else:
kl_weight = tf.cast(1.0, tf.float32) kl_weight = tf.cast(1.0, tf.float32)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment