Commit efa20a5f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

try to support mixed-type multifields; no tests yet

parent 0b2d39bc
......@@ -27,19 +27,22 @@ from ..domain_tuple import DomainTuple
from ..logger import logger
from .iteration_controller import IterationController
from .minimizer import Minimizer
from ..utilities import iscomplextype
def _multiToArray(fld):
szall = 0
for val in fld.values():
if val.dtype != np.float64:
raise TypeError("need float64 fields")
szall += val.size
szall += 2*val.size if iscomplextype(val.dtype) else val.size
res = np.empty(szall, dtype=np.float64)
ofs = 0
for val in fld.values():
res[ofs:ofs+val.size] = val.local_data.reshape(-1)
ofs += val.size
sz2 = 2*val.size if iscomplextype(val.dtype) else val.size
locdat = val.local_data.reshape(-1)
if iscomplextype(val.dtype):
locdat = locdat.astype(np.complex128).view(np.float64)
res[ofs:ofs+sz2] = locdat
ofs += sz2
return res
......@@ -55,16 +58,20 @@ def _toArray_rw(fld):
return _multiToArray(fld)
def _toField(arr, dom):
if isinstance(dom, DomainTuple):
return Field.from_local_data(dom, arr.reshape(dom.shape).copy())
def _toField(arr, template):
if isinstance(template, Field):
return Field.from_local_data(template.domain,
arr.reshape(dom.shape).copy())
ofs = 0
res = []
for d in dom.domains():
res.append(Field.from_local_data(
d, arr[ofs:ofs+d.size].copy().reshape(d.shape)))
ofs += d.size
return MultiField(dom, tuple(res))
for v in template.values():
sz2 = 2*v.size if iscomplextype(v.dtype) else v.size
locdat = arr[ofs:ofs+sz2].copy()
if iscomplextype(v.dtype):
locdat = locdat.view(np.complex128)
res.append(Field.from_local_data(v.domain, locdat.reshape(v.shape)))
ofs += sz2
return MultiField(template.domain, tuple(res))
class _MinHelper(object):
......@@ -73,7 +80,7 @@ class _MinHelper(object):
self._domain = energy.position.domain
def _update(self, x):
pos = _toField(x, self._domain)
pos = _toField(x, self._energy.position)
if (pos != self._energy.position).any():
self._energy = self._energy.at(pos)
......@@ -87,7 +94,7 @@ class _MinHelper(object):
def hessp(self, x, p):
self._update(x)
res = self._energy.metric(_toField(p, self._domain))
res = self._energy.metric(_toField(p, self._energy.position))
return _toArray_rw(res)
......@@ -174,7 +181,7 @@ class ScipyCG(Minimizer):
self._op = op
def __call__(self, inp):
return _toArray(self._op(_toField(inp, self._op.domain)))
return _toArray(self._op(_toField(inp, energy.position)))
op = energy._A
b = _toArray(energy._b)
......@@ -189,4 +196,4 @@ class ScipyCG(Minimizer):
maxiter=self._maxiter)
stat = (IterationController.CONVERGED if stat >= 0 else
IterationController.ERROR)
return energy.at(_toField(res, op.domain)), stat
return energy.at(_toField(res, energy.position)), stat
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment