Commit 64469448 authored by lucas_miranda

Reimplemented MMD warmup using the optimizer's iteration counter, getting rid of the clumsy callback

parent d2c9404c
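Note: the core idea of this commit is to compute the annealing weight inside the loss layers from the optimizer's step counter, instead of having a Keras callback mutate layer state between epochs. A minimal sketch of the pattern, with illustrative names not taken from deepof:

```python
import tensorflow as tf
import tensorflow.keras.backend as K


class WarmupWeightedLoss(tf.keras.layers.Layer):
    """Hypothetical layer (not part of deepof) sketching the callback-free
    warmup pattern: the weight is read from the optimizer's step counter on
    every call, so it advances automatically with each gradient update."""

    def __init__(self, iters, warm_up_iters, **kwargs):
        super().__init__(**kwargs)
        self._iters = iters  # e.g. model.optimizer.iterations (a tf.Variable)
        self._warm_up_iters = warm_up_iters  # steps to anneal over (Python int)

    def call(self, loss_term):
        if self._warm_up_iters > 0:
            # Linear ramp from 0 to 1, clipped at 1 once warmup is over
            weight = tf.cast(
                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
            )
        else:
            weight = tf.cast(1.0, tf.float32)
        weighted = weight * loss_term
        self.add_loss(weighted)
        return weighted
```

With `warm_up_iters = 0` the loss is applied at full strength from the first batch; otherwise it rises linearly to 1 over the warmup horizon and stays there.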
@@ -168,7 +168,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             lstm_units_1,
         ) = self.get_hparams(hp)
 
-        gmvaep, kl_warmup_callback, mmd_warmup_callback = deepof.models.SEQ_2_SEQ_GMVAE(
+        gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
             architecture_hparams={
                 "bidirectional_merge": "ave",
                 "clipvalue": clipvalue,
@@ -187,7 +187,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
             overlap_loss=self.overlap_loss,
             phenotype_prediction=self.pheno_class,
             predictor=self.predictor,
-        ).build(self.input_shape)[3:]
+        ).build(self.input_shape)[-1]
 
         return gmvaep
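Note on the tuner change above: since `build()` no longer returns the two warmup callbacks after the model, the trained model is now simply the last element of the returned tuple, so the slice-and-unpack becomes a single index (`model` below is a stand-in for the configured `deepof.models.SEQ_2_SEQ_GMVAE` instance):

```python
# Before: the tail of the tuple held the model plus the two warmup callbacks
gmvaep, kl_warmup_callback, mmd_warmup_callback = model.build(input_shape)[3:]

# After: warmup lives inside the loss layers, so only the model is needed
gmvaep = model.build(input_shape)[-1]
```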
@@ -445,13 +445,23 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
         config = super().get_config().copy()
         config.update({"is_placeholder": self.is_placeholder})
+        config.update({"_iters": self._iters})
+        config.update({"_warm_up_iters": self._warm_up_iters})
         return config
 
     def call(self, distribution_a):
         """Updates Layer's call method"""
 
-        self._regularizer._weight = K.min([self._iters / self._warm_up_iters, 1.0])
-        kl_batch = self._regularizer(distribution_a)
+        # Define and update KL weight for warmup
+        if self._warm_up_iters > 0:
+            kl_weight = tf.cast(
+                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
+            )
+        else:
+            kl_weight = tf.cast(1.0, tf.float32)
+
+        kl_batch = kl_weight * self._regularizer(distribution_a)
 
         self.add_loss(kl_batch, inputs=[distribution_a])
         self.add_metric(
             kl_batch,
@@ -459,7 +469,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
             name="kl_divergence",
         )
         # noinspection PyProtectedMember
-        self.add_metric(self._regularizer._weight, aggregation="mean", name="kl_rate")
+        self.add_metric(kl_weight, aggregation="mean", name="kl_rate")
 
         return distribution_a
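The two new `get_config` entries let the iteration counter and warmup horizon survive serialization. As a quick sanity check on the weight formula (illustrative numbers, evaluated eagerly): with a 1,000-step horizon the weight ramps linearly and then saturates at 1.

```python
import tensorflow as tf
import tensorflow.keras.backend as K

# Reproduce the kl_weight computation outside the layer
warm_up_iters = tf.constant(1000, tf.int64)
for step in (0, 250, 1000, 5000):
    iters = tf.constant(step, tf.int64)
    kl_weight = tf.cast(K.min([iters / warm_up_iters, 1.0]), tf.float32)
    print(step, float(kl_weight))  # 0.0, 0.25, 1.0, 1.0
```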
@@ -483,8 +493,8 @@ class MMDiscrepancyLayer(Layer):
         config = super().get_config().copy()
         config.update({"batch_size": self.batch_size})
-        config.update({"iters": self._iters})
-        config.update({"warmup_iters": self._warm_up_iters})
+        config.update({"_iters": self._iters})
+        config.update({"_warmup_iters": self._warm_up_iters})
         config.update({"prior": self.prior})
         return config
 
@@ -492,12 +502,17 @@ class MMDiscrepancyLayer(Layer):
         """Updates Layer's call method"""
         true_samples = self.prior.sample(self.batch_size)
 
-        mmd_weight = tf.cast(
-            K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
-        )
+        # Define and update MMD weight for warmup
+        if self._warm_up_iters > 0:
+            mmd_weight = tf.cast(
+                K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
+            )
+        else:
+            mmd_weight = tf.cast(1.0, tf.float32)
 
+        # noinspection PyTypeChecker
         mmd_batch = mmd_weight * compute_mmd((true_samples, z))
 
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
         self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")
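For reference, the quantity being annealed here is the maximum mean discrepancy between samples drawn from the prior and the encoded batch. A generic single-kernel estimator is sketched below; deepof's actual `compute_mmd` may differ (e.g. a multi-scale kernel), so treat this only as an illustration of the statistic:

```python
import tensorflow as tf


def rbf_mmd(x, y, sigma=1.0):
    """Biased MMD^2 estimate between sample sets x and y with one RBF kernel."""

    def kernel(a, b):
        # Pairwise squared Euclidean distances, shape (len(a), len(b))
        sq = tf.reduce_sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return tf.exp(-sq / (2.0 * sigma ** 2))

    return (
        tf.reduce_mean(kernel(x, x))
        + tf.reduce_mean(kernel(y, y))
        - 2.0 * tf.reduce_mean(kernel(x, y))
    )
```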
@@ -625,7 +625,7 @@ class SEQ_2_SEQ_GMVAE:
         # Define and control custom loss functions
         if "ELBO" in self.loss:
 
-            warm_up_iters = tf.cast(
+            kl_warm_up_iters = tf.cast(
                 self.kl_warmup * (input_shape[0] // self.batch_size + 1),
                 tf.int64,
             )
@@ -636,12 +636,12 @@ class SEQ_2_SEQ_GMVAE:
                 test_points_fn=lambda q: q.sample(self.mc_kl),
                 test_points_reduce_axis=0,
                 iters=self.optimizer.iterations,
-                warm_up_iters=warm_up_iters,
+                warm_up_iters=kl_warm_up_iters,
             )(z)
 
         if "MMD" in self.loss:
 
-            warm_up_iters = tf.cast(
+            mmd_warm_up_iters = tf.cast(
                 self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
                 tf.int64,
             )
@@ -650,7 +650,7 @@ class SEQ_2_SEQ_GMVAE:
                 batch_size=self.batch_size,
                 prior=self.prior,
                 iters=self.optimizer.iterations,
-                warm_up_iters=warm_up_iters,
+                warm_up_iters=mmd_warm_up_iters,
             )(z)
 
         # Dummy layer with no parameters, to retrieve the previous tensor
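The warmup horizons above are specified in epochs (`self.kl_warmup`, `self.mmd_warmup`) and converted to optimizer steps by multiplying with the number of batches per epoch, `input_shape[0] // batch_size + 1`. With illustrative numbers:

```python
# Illustrative values only; input_shape[0] is the number of training instances
n_instances = 10000  # input_shape[0]
batch_size = 256
kl_warmup_epochs = 10  # self.kl_warmup

steps_per_epoch = n_instances // batch_size + 1        # 40 batches per epoch
kl_warm_up_iters = kl_warmup_epochs * steps_per_epoch  # anneal over 400 steps
```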
@@ -767,4 +767,3 @@ class SEQ_2_SEQ_GMVAE:
 # - Check usefulness of stateful sequential layers! (stateful=True in the LSTMs)
 # - Investigate full covariance matrix approximation for the latent space! (details on tfp course) :)
 # - Explore expanding the event dims of the final reconstruction layer
-# - Gaussian Mixture as output layer? One component per bodypart (makes sense?)