scipy_minimizer.py 6.28 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

18
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19

20
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..logger import logger
Philipp Arras's avatar
Philipp Arras committed
23
24
from ..multi_field import MultiField
from ..utilities import iscomplextype
Martin Reinecke's avatar
fix    
Martin Reinecke committed
25
from .iteration_controllers import IterationController
26
from .minimizer import Minimizer
27
28


29
def _multiToArray(fld):
    """Flatten all components of a MultiField into a single float64 array.

    Components are concatenated in the order given by ``fld.values()``.
    Complex components are reinterpreted (via ``view``) as their real dtype,
    so each complex entry occupies two consecutive slots of the result.
    """
    def _nslots(comp):
        # Number of real slots this component needs in the flat array.
        return 2*comp.size if iscomplextype(comp.dtype) else comp.size

    res = np.empty(sum(_nslots(v) for v in fld.values()), dtype=np.float64)
    start = 0
    for comp in fld.values():
        stop = start + _nslots(comp)
        flat = comp.local_data.reshape(-1)
        if iscomplextype(comp.dtype):
            flat = flat.view(flat.real.dtype)
        res[start:stop] = flat
        start = stop
    return res


Martin Reinecke's avatar
Martin Reinecke committed
44
def _toArray(fld):
    """Return the local data of `fld` as a one-dimensional array.

    A plain Field is flattened with ``reshape(-1)`` without an explicit copy.
    Every other input is flattened by `_multiToArray`.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    return fld.local_data.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
48
49


Martin Reinecke's avatar
Martin Reinecke committed
50
def _toArray_rw(fld):
    """Return a flat array of `fld`'s data that is safe to modify.

    A plain Field's local data is copied before flattening. For other inputs
    `_multiToArray` already assembles a freshly allocated array.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    data = fld.local_data.copy()
    return data.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
54
55


56
57
58
def _toField(arr, template):
    """Rebuild a field shaped like `template` from the flat array `arr`.

    This is the inverse of the flattening done by `_toArray`/`_multiToArray`.
    For a plain Field template, `arr` is reshaped to ``template.shape`` and
    copied. For a MultiField template, consecutive slices of `arr` are
    assigned to the components in ``template.values()`` order. Complex
    components consume two slots per entry and are reinterpreted as
    complex128. That assumes `arr` is float64, as produced by
    `_multiToArray`.
    """
    if isinstance(template, Field):
        data = arr.reshape(template.shape).copy()
        return Field.from_local_data(template.domain, data)

    parts = []
    start = 0
    for comp in template.values():
        is_cplx = iscomplextype(comp.dtype)
        stop = start + (2*comp.size if is_cplx else comp.size)
        chunk = arr[start:stop].copy()
        if is_cplx:
            # (real, imag) pairs of float64 -> complex128 entries
            chunk = chunk.view(np.complex128)
        parts.append(
            Field.from_local_data(comp.domain, chunk.reshape(comp.shape)))
        start = stop
    return MultiField(template.domain, tuple(parts))
Martin Reinecke's avatar
Martin Reinecke committed
70
71


72
73
74
75
76
77
class _MinHelper(object):
    """Adapter exposing an energy through the flat-array callbacks
    (`fun`, `jac`, `hessp`) that ``scipy.optimize.minimize`` expects.

    The wrapped energy is moved to a new position (``energy.at(...)``) only
    when scipy asks about a point that differs from the current position.
    """

    def __init__(self, energy):
        self._energy = energy
        self._domain = energy.position.domain

    def _update(self, x):
        # Convert scipy's flat array back into a field laid out like the
        # current position, and re-evaluate the energy only on a change.
        current = self._energy.position
        candidate = _toField(x, current)
        if (candidate != current).any():
            self._energy = self._energy.at(candidate)

    def fun(self, x):
        """Return the energy value at `x`."""
        self._update(x)
        return self._energy.value

    def jac(self, x):
        """Return the energy gradient at `x` as a freshly built flat array."""
        self._update(x)
        return _toArray_rw(self._energy.gradient)

    def hessp(self, x, p):
        """Return the energy's metric at `x` applied to the direction `p`."""
        self._update(x)
        direction = _toField(p, self._energy.position)
        return _toArray_rw(self._energy.apply_metric(direction))
Martin Reinecke's avatar
Martin Reinecke committed
94
95
96
97
98
99
100
101
102
103
104
105
106


class ScipyMinimizer(Minimizer):
    """Scipy-based minimizer

    Parameters
    ----------
    method     : str
        The selected Scipy minimization method.
    options    : dictionary
        A set of custom options for the selected minimizer.
    need_hessp : bool
        If True, a Hessian-vector product (the energy's ``apply_metric``) is
        passed to ``scipy.optimize.minimize`` as ``hessp``.
    bounds     : sequence of two values or None
        If given, a ``(lo, hi)`` pair that is applied identically to every
        entry of the flattened position vector. A sequence of any other
        length makes the call raise ValueError.
    """

    def __init__(self, method, options, need_hessp, bounds):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._method = method
        self._options = options
        self._need_hessp = need_hessp
        self._bounds = bounds

    def __call__(self, energy):
        """Minimize `energy` with ``scipy.optimize.minimize``.

        Returns
        -------
        tuple
            ``(final_energy, status)``. `status` is
            ``IterationController.CONVERGED`` if scipy reports success, and
            ``IterationController.ERROR`` otherwise; the scipy message is
            logged in that case.

        Raises
        ------
        ValueError
            If bounds were given but are not a pair.
        """
        import scipy.optimize as opt
        hlp = _MinHelper(energy)
        energy = None  # drop handle, since we don't need it any more
        x = _toArray_rw(hlp._energy.position)
        bounds = None
        if self._bounds is not None:
            if len(self._bounds) != 2:
                raise ValueError("unrecognized bounds")
            lo, hi = self._bounds[0], self._bounds[1]
            # scipy needs exactly one (lo, hi) pair per entry of x. Use len(x)
            # rather than position.size, because complex MultiField
            # components occupy two real slots each in x.
            bounds = [(lo, hi)]*len(x)
        hessp = hlp.hessp if self._need_hessp else None
        r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
                         hessp=hessp, options=self._options, bounds=bounds)
        if not r.success:
            logger.error("Problem in Scipy minimization: {}".format(r.message))
            return hlp._energy, IterationController.ERROR
        return hlp._energy, IterationController.CONVERGED
Martin Reinecke's avatar
Martin Reinecke committed
136
137


138
def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None):
    """Returns a ScipyMinimizer object carrying out the L-BFGS-B algorithm.

    Parameters
    ----------
    ftol, gtol, maxiter, maxcor, disp :
        Forwarded unchanged as the identically named entries of SciPy's
        L-BFGS-B ``options`` dictionary.
    bounds : sequence of two values or None
        Optional ``(lo, hi)`` limits, forwarded to ScipyMinimizer.

    See Also
    --------
    ScipyMinimizer
    """
    scipy_options = dict(ftol=ftol, gtol=gtol, maxiter=maxiter,
                         maxcor=maxcor, disp=disp)
    # L-BFGS-B does not use a Hessian-vector product.
    return ScipyMinimizer("L-BFGS-B", scipy_options, False, bounds)
148
149
150


class ScipyCG(Minimizer):
    """Conjugate-gradient solver for quadratic energies that delegates the
    iteration to ``scipy.sparse.linalg.cg``.

    This is a direct Minimizer subclass, not a ScipyMinimizer. It accepts
    only a ``QuadraticEnergy``. It solves ``A x = b`` for that energy's
    ``_A`` and ``_b``, starting from the energy's current position.

    This class is probably superfluous and can be removed.

    Parameters
    ----------
    tol : float
        Relative residual tolerance. It is passed as ``rtol`` on SciPy
        versions that support it, otherwise as ``tol`` with
        ``atol='legacy'``.
    maxiter : int or None
        Maximum number of CG iterations, passed to SciPy.
    """

    def __init__(self, tol, maxiter):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._tol = tol
        self._maxiter = maxiter

    def __call__(self, energy, preconditioner=None):
        """Solve the linear system of `energy` with SciPy's CG.

        Parameters
        ----------
        energy : QuadraticEnergy
            The energy whose linear system is solved.
        preconditioner : operator or None
            Optional operator passed to SciPy as ``M``.

        Returns
        -------
        tuple
            ``(energy.at(solution), status)``. `status` is
            ``IterationController.CONVERGED`` when SciPy's info code is
            non-negative, and ``IterationController.ERROR`` otherwise.

        Raises
        ------
        ValueError
            If `energy` is not a ``QuadraticEnergy``.
        """
        import inspect
        from scipy.sparse.linalg import LinearOperator as scipy_linop, cg
        from .quadratic_energy import QuadraticEnergy
        if not isinstance(energy, QuadraticEnergy):
            raise ValueError("need a quadratic energy for CG")

        def _matvec(op):
            # Wrap a field operator so that it acts on scipy's flat arrays.
            def apply(inp):
                return _toArray(op(_toField(inp, energy.position)))
            return apply

        op = energy._A
        b = _toArray(energy._b)
        sx = _toArray(energy.position)
        # LinearOperator's shape is (rows, cols) = (output length, input
        # length). The flat array lengths are used rather than
        # op.target/domain.size, because _multiToArray stores two floats per
        # complex entry. CG needs a square system, so M gets the same shape.
        shape = (b.size, sx.size)
        sci_op = scipy_linop(shape=shape, matvec=_matvec(op))
        prec_op = None
        if preconditioner is not None:
            prec_op = scipy_linop(shape=shape,
                                  matvec=_matvec(preconditioner))

        cg_kwargs = {"x0": sx, "M": prec_op, "maxiter": self._maxiter}
        if "rtol" in inspect.signature(cg).parameters:
            # SciPy >= 1.12 deprecates `tol`, and 1.14 removes it together
            # with atol='legacy'. rtol=tol with atol=0 gives essentially the
            # same stopping threshold, ||r|| <= tol*||b||.
            cg_kwargs.update(rtol=self._tol, atol=0.)
        else:
            cg_kwargs.update(tol=self._tol, atol='legacy')
        res, stat = cg(sci_op, b, **cg_kwargs)
        # SciPy info: 0 = converged, > 0 = iteration limit reached,
        # < 0 = illegal input or breakdown. Only the last counts as an error.
        stat = (IterationController.CONVERGED if stat >= 0 else
                IterationController.ERROR)
        return energy.at(_toField(res, energy.position)), stat