Commit d6e82f87 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

streamlining

parent 86a89ed1
Pipeline #26723 passed with stage
in 9 minutes and 47 seconds
......@@ -24,14 +24,26 @@ from ..logger import logger
from .iteration_controller import IterationController
def _toNdarray(fld):
return fld.to_global_data().reshape(-1)
def _toFlatNdarray(fld):
return fld.val.flatten()
def _toField(arr, dom):
return Field.from_global_data(dom, arr.reshape(dom.shape))
class _MinHelper(object):
def __init__(self, energy):
self._energy = energy
self._domain = energy.position.domain
def _update(self, x):
pos = Field(self._domain, x.reshape(self._domain.shape))
if (pos.val != self._energy.position.val).any():
pos = _toField(x, self._domain)
if (pos != self._energy.position).any():
self._energy = self._energy.at(pos.locked_copy())
def fun(self, x):
......@@ -40,13 +52,12 @@ class _MinHelper(object):
def jac(self, x):
self._update(x)
return self._energy.gradient.val.flatten()
return _toFlatNdarray(self._energy.gradient)
def hessp(self, x, p):
self._update(x)
vec = Field(self._domain, p.reshape(self._domain.shape))
res = self._energy.curvature(vec)
return res.val.flatten()
res = self._energy.curvature(_toField(p, self._domain))
return _toFlatNdarray(res)
class ScipyMinimizer(Minimizer):
......@@ -129,22 +140,16 @@ class ScipyCG(Minimizer):
if not isinstance(energy, QuadraticEnergy):
raise ValueError("need a quadratic energy for CG")
def toNdarray(fld):
return fld.to_global_data().reshape(-1)
def toField(arr, dom):
return Field.from_global_data(dom, arr.reshape(dom.shape))
class mymatvec(object):
def __init__(self, op):
self._op = op
def __call__(self, inp):
return toNdarray(self._op(toField(inp, self._op.domain)))
return _toNdarray(self._op(_toField(inp, self._op.domain)))
op = energy._A
b = toNdarray(energy._b)
sx = toNdarray(energy.position)
b = _toNdarray(energy._b)
sx = _toNdarray(energy.position)
sci_op = scipy_linop(shape=(op.domain.size, op.target.size),
matvec=mymatvec(op))
prec_op = None
......@@ -155,4 +160,4 @@ class ScipyCG(Minimizer):
maxiter=self._maxiter)
stat = (IterationController.CONVERGED if stat >= 0 else
IterationController.ERROR)
return energy.at(toField(res, op.domain)), stat
return energy.at(_toField(res, op.domain)), stat
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment