Commit 7816a09a authored by lucas_miranda's avatar lucas_miranda
Browse files

Started implementing annealing mode for KL divergence

parent 3c5ef64e
......@@ -247,8 +247,10 @@ class SEQ_2_SEQ_GMVAE:
batch_size: int = 256,
compile_model: bool = True,
encoding: int = 6,
kl_annealing_mode: str = "sigmoid",
kl_warmup_epochs: int = 20,
loss: str = "ELBO",
mmd_annealing_mode: str = "sigmoid",
mmd_warmup_epochs: int = 20,
montecarlo_kl: int = 1,
neuron_control: bool = False,
......@@ -277,9 +279,11 @@ class SEQ_2_SEQ_GMVAE:
self.learn_rate = self.hparams["learning_rate"]
self.lstm_unroll = True
self.compile = compile_model
self.kl_annealing_mode = kl_annealing_mode
self.kl_warmup = kl_warmup_epochs
self.loss = loss
self.mc_kl = montecarlo_kl
self.mmd_annealing_mode = mmd_annealing_mode
self.mmd_warmup = mmd_warmup_epochs
self.neuron_control = neuron_control
self.number_of_components = number_of_components
......@@ -673,6 +677,7 @@ class SEQ_2_SEQ_GMVAE:
test_points_reduce_axis=0,
iters=self.optimizer.iterations,
warm_up_iters=kl_warm_up_iters,
annealing_mode=self.kl_annealing_mode,
)(z)
if "MMD" in self.loss:
......@@ -686,6 +691,7 @@ class SEQ_2_SEQ_GMVAE:
prior=self.prior,
iters=self.optimizer.iterations,
warm_up_iters=mmd_warm_up_iters,
annealing_mode=self.mmd_annealing_mode,
)(z)
# Dummy layer with no parameters, to retrieve the previous tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment