Commit bba9f951 authored by lucas_miranda

Started implementing annealing mode for KL divergence

parent 7816a09a
Pipeline #100112 canceled with stages in 34 seconds
@@ -884,10 +884,12 @@ class coordinates:
         encoding_size: int = 4,
         epochs: int = 50,
         hparams: dict = None,
+        kl_annealing_mode: str = "linear",
         kl_warmup: int = 0,
         log_history: bool = True,
         log_hparams: bool = False,
         loss: str = "ELBO",
+        mmd_annealing_mode: str = "linear",
         mmd_warmup: int = 0,
         montecarlo_kl: int = 10,
         n_components: int = 25,
@@ -949,10 +951,12 @@ class coordinates:
             encoding_size=encoding_size,
             epochs=epochs,
             hparams=hparams,
+            kl_annealing_mode=kl_annealing_mode,
             kl_warmup=kl_warmup,
             log_history=log_history,
             log_hparams=log_hparams,
             loss=loss,
+            mmd_annealing_mode=mmd_annealing_mode,
             mmd_warmup=mmd_warmup,
             montecarlo_kl=montecarlo_kl,
             n_components=n_components,
@@ -415,7 +415,7 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
     to the final model loss.
     """

-    def __init__(self, iters, warm_up_iters, annealing_mode="sigmoid", *args, **kwargs):
+    def __init__(self, iters, warm_up_iters, annealing_mode, *args, **kwargs):
         super(KLDivergenceLayer, self).__init__(*args, **kwargs)
         self.is_placeholder = True
         self._iters = iters
@@ -476,7 +476,7 @@ class MMDiscrepancyLayer(Layer):
         prior,
         iters,
         warm_up_iters,
-        annealing_mode="sigmoid",
+        annealing_mode,
         *args,
         **kwargs
     ):
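Note that both layer diffs drop the previous `"sigmoid"` default, so `annealing_mode` is now a required constructor argument. The commit does not show how the weight itself is computed; below is a minimal sketch of what a linear vs. sigmoid warm-up schedule could look like (the helper name and the sigmoid steepness are assumptions, not code from this commit):

```python
import numpy as np


def annealing_weight(iteration: int, warm_up_iters: int, mode: str = "linear") -> float:
    """Hypothetical helper: weight in [0, 1] applied to the KL/MMD loss term."""
    if warm_up_iters <= 0:
        return 1.0  # no warm-up requested: apply the full weight immediately
    progress = iteration / warm_up_iters
    if mode == "linear":
        # Ramp up linearly, then saturate at 1 once warm-up is over
        return min(progress, 1.0)
    if mode == "sigmoid":
        # Smooth S-shaped ramp centred halfway through the warm-up phase;
        # the steepness factor 10 is an arbitrary choice for this sketch
        return float(1.0 / (1.0 + np.exp(-10.0 * (progress - 0.5))))
    raise ValueError("annealing_mode must be one of 'linear' and 'sigmoid'")
```

With `mode="linear"` the weight reaches 1 exactly at the end of warm-up; with `"sigmoid"` it approaches 1 smoothly, giving the KL/MMD term a gentler start.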
@@ -98,6 +98,13 @@ parser.add_argument(
     type=str,
     default="dists",
 )
+parser.add_argument(
+    "--kl-annealing-mode",
+    "-klam",
+    help="Weight annealing to use for ELBO loss. Can be one of 'linear' and 'sigmoid'",
+    default="linear",
+    type=str,
+)
 parser.add_argument(
     "--kl-warmup",
     "-klw",
@@ -136,6 +143,13 @@ parser.add_argument(
     default="ELBO+MMD",
     type=str,
 )
+parser.add_argument(
+    "--mmd-annealing-mode",
+    "-mmdam",
+    help="Weight annealing to use for MMD loss. Can be one of 'linear' and 'sigmoid'",
+    default="linear",
+    type=str,
+)
 parser.add_argument(
     "--mmd-warmup",
     "-mmdw",
@@ -251,11 +265,13 @@ gaussian_filter = args.gaussian_filter
 hparams = args.hyperparameters if args.hyperparameters is not None else {}
 input_type = args.input_type
 k = args.components
+kl_annealing_mode = args.kl_annealing_mode
 kl_wu = args.kl_warmup
 entropy_knn = args.entropy_knn
 entropy_samples = args.entropy_samples
 latent_reg = args.latent_reg
 loss = args.loss
+mmd_annealing_mode = args.mmd_annealing_mode
 mmd_wu = args.mmd_warmup
 mc_kl = args.montecarlo_kl
 neuron_control = args.neuron_control
@@ -385,10 +401,12 @@ if not tune:
         batch_size=batch_size,
         encoding_size=encoding_size,
         hparams={},
+        kl_annealing_mode=kl_annealing_mode,
         kl_warmup=kl_wu,
         log_history=True,
         log_hparams=True,
         loss=loss,
+        mmd_annealing_mode=mmd_annealing_mode,
         mmd_warmup=mmd_wu,
         montecarlo_kl=mc_kl,
         n_components=k,
@@ -279,10 +279,12 @@ def autoencoder_fitting(
     encoding_size: int,
     epochs: int,
     hparams: dict,
+    kl_annealing_mode: str,
     kl_warmup: int,
     log_history: bool,
     log_hparams: bool,
     loss: str,
+    mmd_annealing_mode: str,
     mmd_warmup: int,
     montecarlo_kl: int,
     n_components: int,
@@ -382,8 +384,10 @@ def autoencoder_fitting(
         batch_size=batch_size,
         compile_model=True,
         encoding=encoding_size,
+        kl_annealing_mode=kl_annealing_mode,
         kl_warmup_epochs=kl_warmup,
         loss=loss,
+        mmd_annealing_mode=mmd_annealing_mode,
         mmd_warmup_epochs=mmd_warmup,
         montecarlo_kl=montecarlo_kl,
         neuron_control=False,
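For context on where `annealing_mode` would ultimately act: the warm-up weight is typically applied when the layer registers its regularization term on the model loss. Below is a hedged sketch of that pattern (a standalone toy layer; names and the sigmoid shape are assumptions, not the commit's actual KLDivergenceLayer implementation):

```python
import tensorflow as tf


class AnnealedKLWeight(tf.keras.layers.Layer):
    """Toy layer: scales an incoming KL term by an annealed weight.

    Illustrative only; not taken from this commit.
    """

    def __init__(self, iters, warm_up_iters, annealing_mode, **kwargs):
        super().__init__(**kwargs)
        self._iters = iters                  # tensor/variable tracking the current step
        self._warm_up_iters = warm_up_iters  # length of the warm-up phase in steps
        self._annealing_mode = annealing_mode

    def call(self, kl_term):
        step = tf.cast(self._iters, tf.float32)
        warmup = tf.maximum(tf.cast(self._warm_up_iters, tf.float32), 1.0)
        if self._annealing_mode == "linear":
            weight = tf.minimum(step / warmup, 1.0)
        else:  # "sigmoid"
            weight = tf.sigmoid(10.0 * (step / warmup - 0.5))
        # Register the weighted term as part of the model loss; pass the input through
        self.add_loss(weight * kl_term)
        return kl_term
```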