Commit d2c9404c authored by lucas_miranda

Reimplemented MMD warmup using optimizer iterations, getting rid of the clumsy epoch-based callback

parent 24b9c3ea
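In short: the old implementation held the warm-up weight in a Keras backend variable (mmd_beta) and bumped it once per epoch from a LambdaCallback, which had to be created with the model and remembered at fit time. The new implementation hands the optimizer's iteration counter to the layer itself, so the weight is recomputed inside the graph on every batch and no callback is needed. A minimal, self-contained sketch of the pattern, assuming TensorFlow 2.x (the layer name and the toy penalty are illustrative, not deepof code):

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class WarmupPenaltyLayer(Layer):  # hypothetical name, for illustration
    """Identity layer whose added loss ramps linearly from 0 to full
    strength over warm_up_iters optimizer steps."""

    def __init__(self, iters, warm_up_iters, *args, **kwargs):
        super(WarmupPenaltyLayer, self).__init__(*args, **kwargs)
        self._iters = iters  # e.g. model.optimizer.iterations (a tf.Variable)
        self._warm_up_iters = warm_up_iters

    def call(self, z, **kwargs):
        # min(iters / warm_up_iters, 1.0): linear ramp, then clamped at 1
        weight = tf.cast(
            K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
        )
        # Toy penalty standing in for the MMD term in the real layer
        self.add_loss(weight * tf.reduce_mean(tf.square(z)))
        self.add_metric(weight, aggregation="mean", name="warmup_weight")
        return z

Because optimizer.iterations is incremented on every apply_gradients call, the schedule advances per step rather than per epoch. In the diff below, MMDiscrepancyLayer plays this role and receives iters=self.optimizer.iterations.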
@@ -470,19 +470,21 @@ class MMDiscrepancyLayer(Layer):
     to the final model loss.
     """

-    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
+    def __init__(self, batch_size, prior, iters, warm_up_iters, *args, **kwargs):
+        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
         self.is_placeholder = True
         self.batch_size = batch_size
-        self.beta = beta
         self.prior = prior
-        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
+        self._iters = iters
+        self._warm_up_iters = warm_up_iters

     def get_config(self):  # pragma: no cover
         """Updates Constraint metadata"""

         config = super().get_config().copy()
         config.update({"batch_size": self.batch_size})
-        config.update({"beta": self.beta})
+        config.update({"iters": self._iters})
+        config.update({"warmup_iters": self._warm_up_iters})
         config.update({"prior": self.prior})
         return config
@@ -490,11 +492,15 @@ class MMDiscrepancyLayer(Layer):
         """Updates Layer's call method"""

         true_samples = self.prior.sample(self.batch_size)
+
+        mmd_weight = tf.cast(
+            K.min([self._iters / self._warm_up_iters, 1.0]), tf.float32
+        )

         # noinspection PyTypeChecker
-        mmd_batch = self.beta * compute_mmd((true_samples, z))
+        mmd_batch = mmd_weight * compute_mmd((true_samples, z))
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
-        self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
+        self.add_metric(mmd_weight, aggregation="mean", name="mmd_rate")

         return z
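For intuition on the schedule: min(iters / warm_up_iters, 1.0) scales the MMD term linearly during warm-up and clamps it afterwards, and the same weight is logged as the mmd_rate metric, so the ramp is visible during training. A quick pure-Python check with hypothetical numbers:

warm_up_iters = 1000  # hypothetical warm-up length, in optimizer steps
for it in (0, 250, 1000, 4000):
    print(it, min(it / warm_up_iters, 1.0))
# prints 0.0, 0.25, 1.0, 1.0: the weight saturates once warm-up ends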
@@ -639,20 +639,18 @@ class SEQ_2_SEQ_GMVAE:
                 warm_up_iters=warm_up_iters,
             )(z)

-        mmd_warmup_callback = False
         if "MMD" in self.loss:
-            mmd_beta = deepof.model_utils.K.variable(1.0, name="mmd_beta")
-            mmd_beta._trainable = False
-            if self.mmd_warmup:
-                mmd_warmup_callback = LambdaCallback(
-                    on_epoch_begin=lambda epoch, logs: deepof.model_utils.K.set_value(
-                        mmd_beta, deepof.model_utils.K.min([epoch / self.mmd_warmup, 1])
-                    )
-                )
+            warm_up_iters = tf.cast(
+                self.mmd_warmup * (input_shape[0] // self.batch_size + 1),
+                tf.int64,
+            )

             z = deepof.model_utils.MMDiscrepancyLayer(
-                batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
+                batch_size=self.batch_size,
+                prior=self.prior,
+                iters=self.optimizer.iterations,
+                warm_up_iters=warm_up_iters,
             )(z)

         # Dummy layer with no parameters, to retrieve the previous tensor
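The warm-up length is still configured in epochs (self.mmd_warmup) but converted here to optimizer steps, since the new counter ticks once per batch rather than once per epoch. A sketch of the conversion with hypothetical numbers:

mmd_warmup_epochs = 10              # self.mmd_warmup (hypothetical)
n_samples, batch_size = 5000, 256   # input_shape[0], self.batch_size (hypothetical)
steps_per_epoch = n_samples // batch_size + 1        # 19 + 1 = 20
warm_up_iters = mmd_warmup_epochs * steps_per_epoch  # 200 optimizer steps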
@@ -758,7 +756,6 @@ class SEQ_2_SEQ_GMVAE:
             generator,
             grouper,
             gmvaep,
-            mmd_warmup_callback,
         )

     @prior.setter
@@ -327,13 +327,7 @@ def autoencoder_fitting(
         return_list = (encoder, decoder, ae)

     else:
-        (
-            encoder,
-            generator,
-            grouper,
-            ae,
-            mmd_warmup_callback,
-        ) = deepof.models.SEQ_2_SEQ_GMVAE(
+        (encoder, generator, grouper, ae,) = deepof.models.SEQ_2_SEQ_GMVAE(
             architecture_hparams=({} if hparams is None else hparams),
             batch_size=batch_size,
             compile_model=True,
@@ -349,9 +343,7 @@ def autoencoder_fitting(
             predictor=predictor,
             reg_cat_clusters=reg_cat_clusters,
             reg_cluster_variance=reg_cluster_variance,
-        ).build(
-            X_train.shape
-        )
+        ).build(X_train.shape)
         return_list = (encoder, generator, grouper, ae)

     if pretrained:
@@ -394,10 +386,6 @@ def autoencoder_fitting(
         ),
     ]

-    if "MMD" in loss and mmd_warmup > 0:
-        # noinspection PyUnboundLocalVariable
-        callbacks_.append(mmd_warmup_callback)
-
     Xs, ys = [X_train], [X_train]
     Xvals, yvals = [X_val], [X_val]