Commit 38867e7e authored by lucas_miranda's avatar lucas_miranda
Browse files

Started implementing annealing mode for KL divergence

parent 92ceda65
...@@ -438,6 +438,8 @@ else: ...@@ -438,6 +438,8 @@ else:
next_sequence_prediction=next_sequence_prediction, next_sequence_prediction=next_sequence_prediction,
rule_based_prediction=rule_base_prediction, rule_based_prediction=rule_base_prediction,
loss=loss, loss=loss,
loss_warmup=kl_wu,
warmup_mode=kl_annealing_mode,
X_val=(X_val if X_val.shape != (0,) else None), X_val=(X_val if X_val.shape != (0,) else None),
input_type=input_type, input_type=input_type,
cp=False, cp=False,
......
...@@ -73,6 +73,8 @@ def get_callbacks( ...@@ -73,6 +73,8 @@ def get_callbacks(
next_sequence_prediction: float, next_sequence_prediction: float,
rule_based_prediction: float, rule_based_prediction: float,
loss: str, loss: str,
loss_warmup: int = 0,
warmup_mode: str = "none",
X_val: np.array = None, X_val: np.array = None,
input_type: str = False, input_type: str = False,
cp: bool = False, cp: bool = False,
...@@ -108,6 +110,8 @@ def get_callbacks( ...@@ -108,6 +110,8 @@ def get_callbacks(
("_PhenoPred={}".format(phenotype_prediction) if variational else ""), ("_PhenoPred={}".format(phenotype_prediction) if variational else ""),
("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""), ("_RuleBasedPred={}".format(rule_based_prediction) if variational else ""),
("_loss={}".format(loss) if variational else ""), ("_loss={}".format(loss) if variational else ""),
("loss_warmup={}_".format(loss_warmup)),
("warmup_mode={}_".format(warmup_mode)),
("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""), ("_encoding={}".format(logparam["encoding"]) if logparam is not None else ""),
("_k={}".format(logparam["k"]) if logparam is not None else ""), ("_k={}".format(logparam["k"]) if logparam is not None else ""),
("_latreg={}".format(latreg)), ("_latreg={}".format(latreg)),
......
...@@ -157,7 +157,7 @@ rule train_models: ...@@ -157,7 +157,7 @@ rule train_models:
"--loss {wildcards.loss} " "--loss {wildcards.loss} "
"--kl-annealing-mode {wildcards.warmup_mode} " "--kl-annealing-mode {wildcards.warmup_mode} "
"--kl-warmup {wildcards.warmup} " "--kl-warmup {wildcards.warmup} "
"--mmd-annealing-mode sigmoid " "--mmd-annealing-mode {wildcards.warmup_mode} "
"--mmd-warmup {wildcards.warmup} " "--mmd-warmup {wildcards.warmup} "
"--montecarlo-kl 10 " "--montecarlo-kl 10 "
"--encoding-size {wildcards.encs} " "--encoding-size {wildcards.encs} "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment