From 455b0dd97b4b8615f0e8cc003ce4d8add977b438 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Fri, 10 Aug 2018 15:15:14 +0200 Subject: [PATCH] cosmetics; start working on scipy minimizers and multifields --- demos/getting_started_3.py | 2 ++ nifty5/minimization/scipy_minimizer.py | 35 +++++++++++++++++++++-- nifty5/operators/operator.py | 16 +++++------ test/test_minimization/test_minimizers.py | 5 +--- 4 files changed, 42 insertions(+), 16 deletions(-) diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index b0e90e3a2..7b2c33888 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -76,6 +76,8 @@ if __name__ == '__main__': ic_sampling = ift.GradientNormController(iteration_limit=100) ic_newton = ift.GradientNormController(name='Newton', iteration_limit=100) minimizer = ift.RelaxedNewton(ic_newton) + minimizer = ift.NewtonCG(1e-5, 10, True) + minimizer = ift.L_BFGS_B(1e-10, 1e-5, 100, 10, True) # build model Hamiltonian H = ift.Hamiltonian(likelihood, ic_sampling) diff --git a/nifty5/minimization/scipy_minimizer.py b/nifty5/minimization/scipy_minimizer.py index 057355b0d..0727b9b3c 100644 --- a/nifty5/minimization/scipy_minimizer.py +++ b/nifty5/minimization/scipy_minimizer.py @@ -18,24 +18,53 @@ from __future__ import absolute_import, division, print_function +import numpy as np from .. import dobj from ..compat import * from ..field import Field +from ..multi_field import MultiField +from ..domain_tuple import DomainTuple from ..logger import logger from .iteration_controller import IterationController from .minimizer import Minimizer +def _multiToArray(fld): + szall = 0 + for val in fld.values(): + if val.dtype != np.float64: + raise TypeError("need float64 fields") + szall += val.size + res = np.empty(szall, dtype=np.float64) + ofs = 0 + for val in fld.values(): + res[ofs:ofs+val.size] = val.local_data.reshape(-1) + ofs += val.size + return res + + def _toArray(fld): - return fld.to_global_data().reshape(-1) + if isinstance(fld, Field): + return fld.local_data.reshape(-1) + return _multiToArray(fld) def _toArray_rw(fld): - return fld.to_global_data_rw().reshape(-1) + if isinstance(fld, Field): + return fld.local_data.copy().reshape(-1) + return _multiToArray(fld) def _toField(arr, dom): - return Field.from_global_data(dom, arr.reshape(dom.shape).copy()) + if isinstance(dom, DomainTuple): + return Field.from_local_data(dom, arr.reshape(dom.shape).copy()) + ofs = 0 + res = [] + for d in dom.domains(): + res.append(Field.from_local_data( + d, arr[ofs:ofs+d.size].copy().reshape(d.shape))) + ofs += d.size + return MultiField(dom, tuple(res)) class _MinHelper(object): diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 04f91b1ec..11f979ac5 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -132,18 +132,16 @@ class _OpProd(Operator): from ..linearization import Linearization from ..sugar import makeOp lin = isinstance(x, Linearization) + v = x._val if lin else x + v1 = v.extract(self._op1.domain) + v2 = v.extract(self._op2.domain) if not lin: - r1 = self._op1(x.extract(self._op1.domain)) - r2 = self._op2(x.extract(self._op2.domain)) - return r1*r2 - lin1 = self._op1( - Linearization.make_var(x._val.extract(self._op1.domain))) - lin2 = self._op2( - Linearization.make_var(x._val.extract(self._op2.domain))) + return self._op1(v1) * self._op2(v2) + lin1 = self._op1(Linearization.make_var(v1)) + lin2 = self._op2(Linearization.make_var(v2)) op = (makeOp(lin1._val)(lin2._jac))._myadd( makeOp(lin2._val)(lin1._jac), False) - jac = op(x.jac) - return Linearization(lin1._val*lin2._val, jac) + return Linearization(lin1._val*lin2._val, op(x.jac)) class _OpSum(_CombinedOperator): diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index 35be855ca..e21051327 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -100,10 +100,7 @@ class Test_Minimizers(unittest.TestCase): def __init__(self, loc): self._loc = loc.to_global_data_rw() self._capability = self.TIMES - - @property - def domain(self): - return space + self._domain = space def apply(self, x, mode): self._check_input(x, mode) -- GitLab