Skip to content
Snippets Groups Projects
Commit 9f7abcf1 authored by Jakob Roth's avatar Jakob Roth
Browse files

Merge branch 'fix_nullop_stat_ncg' into 'NIFTy_8'

_static_newton_cg: Fix maxiter==0 handling

Closes #426

See merge request !1009
parents 100c6cb9 88b56d4a
No related branches found
No related tags found
1 merge request!1009_static_newton_cg: Fix maxiter==0 handling
Pipeline #246100 passed
......@@ -327,8 +327,8 @@ def _static_newton_cg(
energy, g = fun_and_grad(pos)
conditional_raise(jnp.isnan(energy), ValueError("energy is Nan"))
val = {
"status": -2,
"iteration": 1,
"status": jnp.where(maxiter == 0, 0, -2),
"iteration": 0,
"pos": pos,
"energy": energy,
"old_energy": old_fval if old_fval is not None else jnp.inf,
......@@ -342,7 +342,7 @@ def _static_newton_cg(
return v["status"] < -1
def single_newton_cg_step(v):
status, i = v["status"], v["iteration"]
status, i = v["status"], v["iteration"] + 1
pos = v["pos"]
energy, g = v["energy"], v["g"]
old_energy = v["old_energy"]
......@@ -419,7 +419,7 @@ def _static_newton_cg(
ret = {
"status": status,
"iteration": i + 1,
"iteration": i,
"pos": pos,
"energy": energy,
"old_energy": old_energy,
......@@ -432,7 +432,9 @@ def _static_newton_cg(
val = while_loop(continue_condition_newton_cg, single_newton_cg_step, val)
conditional_call(
val["status"] > 0, logger.error, PyTreeString(f"{nm}: Iteration Limit Reached!")
(val["status"] > 0) | (maxiter == 0),
logger.error,
PyTreeString(f"{nm}: Iteration Limit Reached!"),
)
return OptimizeResults(
x=val["pos"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment