Commit 38867e7e authored by lucas_miranda's avatar lucas_miranda
Browse files

Started implementing annealing mode for KL divergence

parent 92ceda65
......@@ -438,6 +438,8 @@ else:
next_sequence_prediction=next_sequence_prediction,
rule_based_prediction=rule_base_prediction,
loss=loss,
loss_warmup=kl_wu,
warmup_mode=kl_annealing_mode,
X_val=(X_val if X_val.shape != (0,) else None),
input_type=input_type,
cp=False,
......
......@@ -73,6 +73,8 @@ def get_callbacks(
next_sequence_prediction: float,
rule_based_prediction: float,
loss: str,
loss_warmup: int = 0,
warmup_mode: str = "none",
X_val: np.array = None,
input_type: str = False,
cp: bool = False,
......@@ -108,6 +110,8 @@ def get_callbacks(
("_PhenoPred={}".format(phenotype_prediction) if variational else ""),
("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""),
("_loss={}".format(loss) if variational else ""),
("loss_warmup={}_".format(loss_warmup)),
("warmup_mode={}_".format(warmup_mode)),
("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
("_k={}".format(logparam["k"]) if logparam is not None else ""),
("_latreg={}".format(latreg)),
......
......@@ -157,7 +157,7 @@ rule train_models:
"--loss {wildcards.loss} "
"--kl-annealing-mode {wildcards.warmup_mode} "
"--kl-warmup {wildcards.warmup} "
"--mmd-annealing-mode sigmoid "
"--mmd-annealing-mode {wildcards.warmup_mode} "
"--mmd-warmup {wildcards.warmup} "
"--montecarlo-kl 10 "
"--encoding-size {wildcards.encs} "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment