Commit 64469448 authored by lucas_miranda's avatar lucas_miranda
Browse files

Reimplemented MMD warmup using optimizer iterators; getting rid of the clumsy callback

parent d2c9404c
......@@ -168,7 +168,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
lstm_units_1,
) = self.get_hparams(hp)
gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
architecture_hparams={
"bidirectional_merge": "ave",
"clipvalue": clipvalue,
......@@ -187,7 +187,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
overlap_loss=self.overlap_loss,
phenotype_prediction=self.pheno_class,
predictor=self.predictor,
).build(self.input_shape)[3:]
).build(self.input_shape)[-1]
return gmvaep
......
......@@ -445,13 +445,23 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
config = super().get_config().copy()
config.update({"is_placeholder": self.is_placeholder})
config.update({"_iters": self._iters})
config.update({"_warm_up_iters": self._warm_up_iters})
return config
def call(self, distribution_a):
"""Updates Layer's call method"""
self._regularizer._weight = K.min([self._iters / self._warm_up_iters, 1.0])
kl_batch = self._regularizer(distribution_a)
# Define and update KL weight for warmup
if self._warm_up_iters > 0:
kl_weight = tf.cast(
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
else:
kl_weight = tf.cast(1.0, tf.float32)
kl_batch = kl_weight * self._regularizer(distribution_a)
self.add_loss(kl_batch, inputs=[distribution_a])
self.add_metric(
kl_batch,
......@@ -459,7 +469,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
name="kl_divergence",
)
# noinspection PyProtectedMember
self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
self.add_metric(kl_weight, aggregation="mean", name="kl_rate")
return distribution_a
......@@ -483,8 +493,8 @@ class MMDiscrepancyLayer(Layer):
config = super().get_config().copy()
config.update({"batch_size": self.batch_size})
config.update({"iters": self._iters})
config.update({"warmup_iters": self._warm_up_iters})
config.update({"_iters": self._iters})
config.update({"_warmup_iters": self._warm_up_iters})
config.update({"prior": self.prior})
return config
......@@ -492,12 +502,17 @@ class MMDiscrepancyLayer(Layer):
"""Updates Layer's call method"""
true_samples = self.prior.sample(self.batch_size)
mmd_weight = tf.cast(
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
# noinspection PyTypeChecker
# Define and update MMD weight for warmup
if self._warm_up_iters > 0:
mmd_weight = tf.cast(
K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
)
else:
mmd_weight = tf.cast(1.0, tf.float32)
mmd_batch = mmd_weight * compute_mmd((true_samples, z))
self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd")
self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")
......
......@@ -625,7 +625,7 @@ class SEQ_2_SEQ_GMVAE:
# Define and control custom loss functions
if "ELBO" in self.loss:
warm_up_iters = tf.cast(
kl_warm_up_iters = tf.cast(
self.kl_warmup * (input_shape[0] // self.batch_size + 1),
tf.int64,
)
......@@ -636,12 +636,12 @@ class SEQ_2_SEQ_GMVAE:
test_points_fn=lambda q: q.sample(self.mc_kl),
test_points_reduce_axis=0,
iters=self.optimizer.iterations,
warm_up_iters=warm_up_iters,
warm_up_iters=kl_warm_up_iters,
)(z)
if "MMD" in self.loss:
warm_up_iters = tf.cast(
mmd_warm_up_iters = tf.cast(
self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
tf.int64,
)
......@@ -650,7 +650,7 @@ class SEQ_2_SEQ_GMVAE:
batch_size=self.batch_size,
prior=self.prior,
iters=self.optimizer.iterations,
warm_up_iters=warm_up_iters,
warm_up_iters=mmd_warm_up_iters,
)(z)
# Dummy layer with no parameters, to retrieve the previous tensor
......@@ -767,4 +767,3 @@ class SEQ_2_SEQ_GMVAE:
# - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
# - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
# - Explore expanding the event dims of the final reconstruction layer
# - Gaussian Mixture as output layer? One component per bodypart (makes sense?)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment