# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

18
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19

20
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..logger import logger
Philipp Arras's avatar
Philipp Arras committed
23
24
from ..multi_field import MultiField
from ..utilities import iscomplextype
Martin Reinecke's avatar
fix    
Martin Reinecke committed
25
from .iteration_controllers import IterationController
26
from .minimizer import Minimizer
27
28


29
def _multiToArray(fld):
    """Flatten all entries of a MultiField into one contiguous float64 array.

    Entries are concatenated in the order yielded by ``fld.values()``.
    Complex entries are reinterpreted via their real dtype, so every complex
    number occupies two consecutive slots (real part, then imaginary part).

    Parameters
    ----------
    fld : MultiField
        Container whose values expose ``size``, ``dtype`` and ``local_data``.

    Returns
    -------
    numpy.ndarray
        Newly allocated one-dimensional float64 array.
    """
    entries = list(fld.values())
    # number of real slots each entry needs (complex entries need two)
    widths = [2*entry.size if iscomplextype(entry.dtype) else entry.size
              for entry in entries]
    packed = np.empty(sum(widths), dtype=np.float64)
    start = 0
    for entry, width in zip(entries, widths):
        flat = entry.local_data.reshape(-1)
        if iscomplextype(entry.dtype):
            # reinterpret complex storage as interleaved (re, im) pairs
            flat = flat.view(flat.real.dtype)
        packed[start:start + width] = flat
        start += width
    return packed


Martin Reinecke's avatar
Martin Reinecke committed
44
def _toArray(fld):
    """Return ``fld`` as a flat one-dimensional array for SciPy.

    For a :class:`Field` this is ``local_data`` reshaped to 1-D; it keeps the
    field's dtype and may share memory with the field. Anything else is
    treated as a MultiField and packed by :func:`_multiToArray`.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    return fld.local_data.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
48
49


Martin Reinecke's avatar
Martin Reinecke committed
50
def _toArray_rw(fld):
    """Like :func:`_toArray`, but the result never aliases the input's data.

    A :class:`Field` is copied before flattening; a MultiField is packed by
    :func:`_multiToArray`, which always allocates a new array. The result may
    therefore be handed to code that modifies it in place.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    private_copy = fld.local_data.copy()
    return private_copy.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
54
55


56
57
58
def _toField(arr, template):
    """Unpack a flat array into a Field/MultiField shaped like ``template``.

    This is the inverse of :func:`_toArray` / :func:`_multiToArray`.

    Parameters
    ----------
    arr : numpy.ndarray
        One-dimensional data. For a MultiField template, complex entries are
        expected as interleaved (real, imag) float64 pairs, as produced by
        :func:`_multiToArray`.
    template : Field or MultiField
        Supplies the domain and shape of the result, and the dtype of any
        complex MultiField entry.

    Returns
    -------
    Field or MultiField
        A new object owning a copy of the data in ``arr``.
    """
    if isinstance(template, Field):
        return Field.from_local_data(template.domain,
                                     arr.reshape(template.shape).copy())
    ofs = 0
    res = []
    for v in template.values():
        sz2 = 2*v.size if iscomplextype(v.dtype) else v.size
        locdat = arr[ofs:ofs+sz2].copy()
        if iscomplextype(v.dtype):
            # Reassemble the (re, im) pairs, then restore the template's
            # complex precision: _multiToArray widens e.g. complex64 parts
            # to float64, so a bare complex128 view would silently change
            # the entry's dtype. For complex128 templates this is a no-op.
            locdat = locdat.view(np.complex128).astype(v.dtype, copy=False)
        res.append(Field.from_local_data(v.domain, locdat.reshape(v.shape)))
        ofs += sz2
    return MultiField(template.domain, tuple(res))
Martin Reinecke's avatar
Martin Reinecke committed
70
71


72
73
74
75
76
77
class _MinHelper(object):
    """Adapter exposing a NIFTy energy through SciPy's flat-array callbacks.

    SciPy hands positions over as one-dimensional arrays. Every callback
    converts the array back into the energy's position type and moves the
    cached energy only when the position actually differs.
    """

    def __init__(self, energy):
        self._energy = energy
        self._domain = energy.position.domain

    def _update(self, x):
        """Re-evaluate the cached energy at ``x`` if ``x`` is a new position."""
        current = self._energy.position
        candidate = _toField(x, current)
        if (candidate != current).any():
            self._energy = self._energy.at(candidate)

    def fun(self, x):
        """SciPy objective callback: energy value at ``x``."""
        self._update(x)
        return self._energy.value

    def jac(self, x):
        """SciPy gradient callback: gradient at ``x`` as a fresh flat array."""
        self._update(x)
        gradient = self._energy.gradient
        return _toArray_rw(gradient)

    def hessp(self, x, p):
        """SciPy ``hessp`` callback: applies the energy's metric at ``x`` to
        the direction ``p`` and returns the result as a fresh flat array."""
        self._update(x)
        direction = _toField(p, self._energy.position)
        return _toArray_rw(self._energy.apply_metric(direction))
Martin Reinecke's avatar
Martin Reinecke committed
94
95


Martin Reinecke's avatar
Martin Reinecke committed
96
class _ScipyMinimizer(Minimizer):
    """Scipy-based minimizer

    Parameters
    ----------
    method : str
        The selected Scipy minimization method (passed to
        ``scipy.optimize.minimize``).
    options : dictionary
        A set of custom options for the selected minimizer.
    need_hessp : bool
        Whether a Hessian-vector-product callback (built from the energy's
        metric) is passed to the minimizer.
    bounds : sequence of length 2 or None
        If given as ``(lo, hi)``, the same box constraint is applied to every
        entry of the flattened position vector. Any other length raises
        ``ValueError`` when the minimizer is called.

    Notes
    -----
    Raises ``NotImplementedError`` unless the numpy data-object backend is
    active (``dobj.is_numpy()``).
    """

    def __init__(self, method, options, need_hessp, bounds):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._method = method
        self._options = options
        self._need_hessp = need_hessp
        self._bounds = bounds

    def __call__(self, energy):
        """Minimize ``energy`` with ``scipy.optimize.minimize``.

        Returns
        -------
        tuple
            ``(final_energy, status)`` where ``status`` is
            ``IterationController.CONVERGED`` if SciPy reports success and
            ``IterationController.ERROR`` otherwise.

        Raises
        ------
        ValueError
            If ``bounds`` was given but does not have exactly two entries.
        """
        import scipy.optimize as opt
        hlp = _MinHelper(energy)
        energy = None  # drop handle, since we don't need it any more
        x = _toArray_rw(hlp._energy.position)
        bounds = None
        if self._bounds is not None:
            if len(self._bounds) != 2:
                raise ValueError("unrecognized bounds")
            lo, hi = self._bounds[0], self._bounds[1]
            # SciPy needs one (lo, hi) pair per entry of the flat vector it
            # optimizes. That is x.size, not position.size: complex
            # MultiField entries occupy two slots each in x.
            bounds = [(lo, hi)]*x.size
        hessp = hlp.hessp if self._need_hessp else None
        r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
                         hessp=hessp, options=self._options, bounds=bounds)
        if not r.success:
            logger.error("Problem in Scipy minimization: {}".format(r.message))
            return hlp._energy, IterationController.ERROR
        return hlp._energy, IterationController.CONVERGED
Martin Reinecke's avatar
Martin Reinecke committed
136
137


138
def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None):
    """Returns a _ScipyMinimizer object carrying out the L-BFGS-B algorithm.

    Parameters
    ----------
    ftol, gtol, maxiter, maxcor, disp :
        Forwarded unchanged as the entries of the same names in SciPy's
        L-BFGS-B ``options`` dictionary.
    bounds : sequence of length 2 or None
        Optional ``(lo, hi)`` box constraint, forwarded to _ScipyMinimizer.

    See Also
    --------
    _ScipyMinimizer
    """
    lbfgsb_options = dict(ftol=ftol, gtol=gtol, maxiter=maxiter,
                          maxcor=maxcor, disp=disp)
    # need_hessp=False: no Hessian-vector-product callback is passed to SciPy.
    return _ScipyMinimizer("L-BFGS-B", lbfgsb_options, need_hessp=False,
                           bounds=bounds)
148
149


Martin Reinecke's avatar
Martin Reinecke committed
150
151
class _ScipyCG(Minimizer):
    """Conjugate gradient minimizer for quadratic energies, backed by
    ``scipy.sparse.linalg.cg``.

    This class is only intended for double-checking NIFTy's own conjugate
    gradient implementation and should not be used otherwise.

    Parameters
    ----------
    tol : float
        Relative residual tolerance handed to SciPy's ``cg``.
    maxiter : int or None
        Maximum number of CG iterations.
    """

    def __init__(self, tol, maxiter):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._tol = tol
        self._maxiter = maxiter

    def __call__(self, energy, preconditioner=None):
        """Solve the linear system of ``energy``, starting at its position.

        Parameters
        ----------
        energy : QuadraticEnergy
            Provides the operator (``energy._A``), the right-hand side
            (``energy._b``) and the starting position.
        preconditioner : operator or None
            Optional preconditioner, passed to SciPy as ``M``.

        Returns
        -------
        tuple
            ``(energy.at(solution), status)``. ``status`` is
            ``IterationController.CONVERGED`` for any non-negative SciPy exit
            code (including an exhausted iteration budget) and
            ``IterationController.ERROR`` otherwise.

        Raises
        ------
        ValueError
            If ``energy`` is not a QuadraticEnergy.
        """
        import inspect
        from scipy.sparse.linalg import LinearOperator as scipy_linop, cg
        from .quadratic_energy import QuadraticEnergy
        if not isinstance(energy, QuadraticEnergy):
            raise ValueError("need a quadratic energy for CG")

        template = energy.position

        def as_matvec(op):
            # Wrap a NIFTy operator so that it acts on flat arrays.
            return lambda inp: _toArray(op(_toField(inp, template)))

        b = _toArray(energy._b)
        sx = _toArray(template)
        # Size the (square) operators by the flattened vector SciPy really
        # works on: complex MultiField entries occupy two slots each, so the
        # domain/target sizes would not match it.
        n = sx.size
        sci_op = scipy_linop(shape=(n, n), matvec=as_matvec(energy._A))
        prec_op = None
        if preconditioner is not None:
            prec_op = scipy_linop(shape=(n, n),
                                  matvec=as_matvec(preconditioner))
        # Newer SciPy versions take ``rtol`` and no longer accept the old
        # ``tol``/``atol='legacy'`` spelling. ``rtol`` with ``atol=0`` keeps
        # the same relative stopping threshold (tol*||b||).
        if "rtol" in inspect.signature(cg).parameters:
            tol_kwargs = {"rtol": self._tol, "atol": 0.}
        else:
            tol_kwargs = {"tol": self._tol, "atol": 'legacy'}
        res, stat = cg(sci_op, b, x0=sx, M=prec_op, maxiter=self._maxiter,
                       **tol_kwargs)
        # SciPy: 0 = converged, >0 = iteration limit reached, <0 = bad input.
        # Hitting the iteration limit is deliberately reported as CONVERGED.
        stat = (IterationController.CONVERGED if stat >= 0 else
                IterationController.ERROR)
        return energy.at(_toField(res, template)), stat