Commit 90ba416a authored by lucas_miranda's avatar lucas_miranda
Browse files

Reimplemented KL warmup using optimizer iterators; getting rid of the clumsy callback

parent 547af774
......@@ -221,16 +221,17 @@ class neighbor_cluster_purity(tf.keras.callbacks.Callback):
):
super().__init__()
self.enc = encoding_dim
self.r = (
-0.14220132706202965 * np.log2(validation_data.shape[0])
+ 0.17189696892334544 * self.enc
+ 1.6940295848037952
) # Empirically derived from data. See examples/set_default_entropy_radius.ipynb for details
self.variational = variational
self.validation_data = validation_data
self.samples = samples
self.log_dir = log_dir
self.min_n = min_n
if self.validation_data is not None:
self.r = (
-0.14220132706202965 * np.log2(validation_data.shape[0])
+ 0.17189696892334544 * self.enc
+ 1.6940295848037952
) # Empirically derived from data. See examples/set_default_entropy_radius.ipynb for details
# noinspection PyMethodOverriding,PyTypeChecker
def on_epoch_end(self, epoch, logs=None):
......@@ -438,7 +439,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
self.is_placeholder = True
self._iters = iters
self._warm_up_iters = warm_up_iters
self._regularizer._weight = K.min(self._iters / self._warm_up_iters, 1)
self._regularizer._weight = K.min([self._iters / self._warm_up_iters, 1.0])
def get_config(self): # pragma: no cover
"""Updates Constraint metadata"""
......
......@@ -623,16 +623,19 @@ class SEQ_2_SEQ_GMVAE:
)([z_cat, z_gauss])
# Define and control custom loss functions
kl_warmup_callback = False
if "ELBO" in self.loss:
warm_up_iters = tf.cast(
self.kl_warmup * (input_shape[0] / self.batch_size), tf.int64
)
# noinspection PyCallingNonCallable
z = deepof.model_utils.KLDivergenceLayer(
self.prior,
distribution_b=self.prior,
test_points_fn=lambda q: q.sample(self.mc_kl),
test_points_reduce_axis=0,
iters=self.optimizer.iterations,
warm_up_iters=self.kl_warmup,
warm_up_iters=warm_up_iters,
)(z)
mmd_warmup_callback = False
......@@ -754,7 +757,6 @@ class SEQ_2_SEQ_GMVAE:
generator,
grouper,
gmvaep,
kl_warmup_callback,
mmd_warmup_callback,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment