diff --git a/src/re/optimize.py b/src/re/optimize.py index 87e0d9d026ce1065931544de1fbcad7cda9f59a6..cb6eb422a72a5504500e847e4d2dfe02c6e73a86 100644 --- a/src/re/optimize.py +++ b/src/re/optimize.py @@ -327,8 +327,8 @@ def _static_newton_cg( energy, g = fun_and_grad(pos) conditional_raise(jnp.isnan(energy), ValueError("energy is Nan")) val = { - "status": -2, - "iteration": 1, + "status": jnp.where(maxiter == 0, 0, -2), + "iteration": 0, "pos": pos, "energy": energy, "old_energy": old_fval if old_fval is not None else jnp.inf, @@ -342,7 +342,7 @@ def _static_newton_cg( return v["status"] < -1 def single_newton_cg_step(v): - status, i = v["status"], v["iteration"] + status, i = v["status"], v["iteration"] + 1 pos = v["pos"] energy, g = v["energy"], v["g"] old_energy = v["old_energy"] @@ -419,7 +419,7 @@ def _static_newton_cg( ret = { "status": status, - "iteration": i + 1, + "iteration": i, "pos": pos, "energy": energy, "old_energy": old_energy, @@ -432,7 +432,9 @@ def _static_newton_cg( val = while_loop(continue_condition_newton_cg, single_newton_cg_step, val) conditional_call( - val["status"] > 0, logger.error, PyTreeString(f"{nm}: Iteration Limit Reached!") + (val["status"] > 0) | (maxiter == 0), + logger.error, + PyTreeString(f"{nm}: Iteration Limit Reached!"), ) return OptimizeResults( x=val["pos"],