diff --git a/nifty4/minimization/scipy_minimizer.py b/nifty4/minimization/scipy_minimizer.py index fef9e7d2933d66668fd9c0bcec40c440901119b0..7356c3e1cf86728559129dbba7c179caad189891 100644 --- a/nifty4/minimization/scipy_minimizer.py +++ b/nifty4/minimization/scipy_minimizer.py @@ -24,14 +24,26 @@ from ..logger import logger from .iteration_controller import IterationController +def _toNdarray(fld): + return fld.to_global_data().reshape(-1) + + +def _toFlatNdarray(fld): + return fld.val.flatten() + + +def _toField(arr, dom): + return Field.from_global_data(dom, arr.reshape(dom.shape)) + + class _MinHelper(object): def __init__(self, energy): self._energy = energy self._domain = energy.position.domain def _update(self, x): - pos = Field(self._domain, x.reshape(self._domain.shape)) - if (pos.val != self._energy.position.val).any(): + pos = _toField(x, self._domain) + if (pos != self._energy.position).any(): self._energy = self._energy.at(pos.locked_copy()) def fun(self, x): @@ -40,13 +52,12 @@ class _MinHelper(object): def jac(self, x): self._update(x) - return self._energy.gradient.val.flatten() + return _toFlatNdarray(self._energy.gradient) def hessp(self, x, p): self._update(x) - vec = Field(self._domain, p.reshape(self._domain.shape)) - res = self._energy.curvature(vec) - return res.val.flatten() + res = self._energy.curvature(_toField(p, self._domain)) + return _toFlatNdarray(res) class ScipyMinimizer(Minimizer): @@ -129,22 +140,16 @@ class ScipyCG(Minimizer): if not isinstance(energy, QuadraticEnergy): raise ValueError("need a quadratic energy for CG") - def toNdarray(fld): - return fld.to_global_data().reshape(-1) - - def toField(arr, dom): - return Field.from_global_data(dom, arr.reshape(dom.shape)) - class mymatvec(object): def __init__(self, op): self._op = op def __call__(self, inp): - return toNdarray(self._op(toField(inp, self._op.domain))) + return _toNdarray(self._op(_toField(inp, self._op.domain))) op = energy._A - b = toNdarray(energy._b) - sx = toNdarray(energy.position) + b = _toNdarray(energy._b) + sx = _toNdarray(energy.position) sci_op = scipy_linop(shape=(op.domain.size, op.target.size), matvec=mymatvec(op)) prec_op = None @@ -155,4 +160,4 @@ class ScipyCG(Minimizer): maxiter=self._maxiter) stat = (IterationController.CONVERGED if stat >= 0 else IterationController.ERROR) - return energy.at(toField(res, op.domain)), stat + return energy.at(_toField(res, op.domain)), stat