scipy_minimizer.py 6.59 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Martin Reinecke's avatar
Martin Reinecke committed
14
# Copyright(C) 2013-2018 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
16
17
18
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19
from __future__ import absolute_import, division, print_function
20

21
import numpy as np
22
from .. import dobj
23
from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
24
from ..field import Field
25
26
from ..multi_field import MultiField
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
27
from ..logger import logger
28
from .iteration_controller import IterationController
29
from .minimizer import Minimizer
30
from ..utilities import iscomplextype
31
32


33
34
35
def _multiToArray(fld):
    """Flatten all entries of a MultiField into one new float64 vector.

    Entries are laid out back to back in the order ``fld.values()`` yields
    them. A complex-valued entry is converted to complex128 and stored as
    consecutive (real, imaginary) float64 pairs, so it occupies twice as
    many slots as it has elements.
    """
    entries = list(fld.values())
    widths = [2*e.size if iscomplextype(e.dtype) else e.size
              for e in entries]
    out = np.empty(sum(widths), dtype=np.float64)
    start = 0
    for entry, width in zip(entries, widths):
        flat = entry.local_data.reshape(-1)
        if iscomplextype(entry.dtype):
            # reinterpret every complex number as two adjacent floats
            flat = flat.astype(np.complex128).view(np.float64)
        out[start:start + width] = flat
        start += width
    return out


Martin Reinecke's avatar
Martin Reinecke committed
49
def _toArray(fld):
    """Return the data of *fld* as a flat 1-D numpy array.

    For a Field this is ``local_data`` reshaped to one dimension; numpy
    returns a view when it can, so the result may share memory with the
    field and should not be modified. Any other input is handed to
    ``_multiToArray``, which builds a fresh array.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    return fld.local_data.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
53
54


Martin Reinecke's avatar
Martin Reinecke committed
55
def _toArray_rw(fld):
    """Return the data of *fld* as a flat 1-D array that is safe to modify.

    For a Field the local data is copied first, so the result never aliases
    the field. Any other input goes through ``_multiToArray``, which already
    allocates a new array.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    private_copy = fld.local_data.copy()
    return private_copy.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
59
60


61
62
63
def _toField(arr, template):
    """Rebuild a Field or MultiField shaped like *template* from flat *arr*.

    Parameters
    ----------
    arr : numpy.ndarray
        Flat 1-D array. For a MultiField template it must follow the layout
        of ``_multiToArray``; complex entries are expected as float64
        (real, imaginary) pairs.
    template : Field or MultiField
        Provides the domain(s) and shape(s) of the result and, for
        MultiField entries, whether each entry is complex.

    Returns
    -------
    Field or MultiField
        A new object whose data is a copy of the relevant part of *arr*.
    """
    if not isinstance(template, Field):
        pieces = []
        start = 0
        for entry in template.values():
            is_cplx = iscomplextype(entry.dtype)
            stop = start + (2*entry.size if is_cplx else entry.size)
            chunk = arr[start:stop].copy()
            if is_cplx:
                # pairs of floats -> one complex128 value each
                chunk = chunk.view(np.complex128)
            pieces.append(
                Field.from_local_data(entry.domain, chunk.reshape(entry.shape)))
            start = stop
        return MultiField(template.domain, tuple(pieces))
    return Field.from_local_data(template.domain,
                                 arr.reshape(template.shape).copy())
Martin Reinecke's avatar
Martin Reinecke committed
75
76


77
78
79
80
81
82
class _MinHelper(object):
    """Adapter exposing a NIFTy energy through flat-array callbacks.

    ``fun``, ``jac`` and ``hessp`` take flat numpy arrays (as scipy passes
    them). The helper caches the most recent energy and only calls
    ``energy.at`` when asked about a position that differs from it.
    """

    def __init__(self, energy):
        self._energy = energy
        self._domain = energy.position.domain

    def _update(self, x):
        """Move the cached energy to the flat position *x* if it changed."""
        candidate = _toField(x, self._energy.position)
        moved = (candidate != self._energy.position).any()
        if moved:
            self._energy = self._energy.at(candidate)

    def fun(self, x):
        """Energy value at *x*."""
        self._update(x)
        return self._energy.value

    def jac(self, x):
        """Energy gradient at *x*, flattened into a newly allocated array."""
        self._update(x)
        gradient = self._energy.gradient
        return _toArray_rw(gradient)

    def hessp(self, x, p):
        """Apply the energy's metric at *x* to the flat direction *p*."""
        self._update(x)
        metric = self._energy.metric
        product = metric(_toField(p, self._energy.position))
        return _toArray_rw(product)
Martin Reinecke's avatar
Martin Reinecke committed
99
100
101
102
103
104
105
106
107
108
109
110
111


class ScipyMinimizer(Minimizer):
    """Scipy-based minimizer

    Parameters
    ----------
    method     : str
        The selected Scipy minimization method.
    options    : dictionary
        A set of custom options for the selected minimizer.
    need_hessp : bool
        If True, a Hessian-vector product callback built from the energy's
        ``metric`` is passed to ``scipy.optimize.minimize`` as ``hessp``.
    bounds     : pair (lo, hi) or None
        If given, the same ``(lo, hi)`` bound is applied to every entry of
        the flattened position vector. ``None`` means no bounds.
    """

    def __init__(self, method, options, need_hessp, bounds):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._method = method
        self._options = options
        self._need_hessp = need_hessp
        self._bounds = bounds

    def __call__(self, energy):
        """Minimize *energy* with ``scipy.optimize.minimize``.

        Returns
        -------
        tuple
            The energy at scipy's final position, and
            ``IterationController.CONVERGED`` if scipy reports success or
            ``IterationController.ERROR`` otherwise.

        Raises
        ------
        ValueError
            If ``bounds`` was given but is not a pair ``(lo, hi)``.
        """
        import scipy.optimize as opt
        hlp = _MinHelper(energy)
        energy = None  # drop handle, since we don't need it any more
        x = _toArray_rw(hlp._energy.position)
        bounds = None
        if self._bounds is not None:
            if len(self._bounds) != 2:
                raise ValueError("unrecognized bounds")
            lo, hi = self._bounds[0], self._bounds[1]
            # scipy wants one (lo, hi) pair per entry of x0. Use x.size, not
            # position.size: complex MultiField entries take two floats in x.
            bounds = [(lo, hi)]*x.size
        hessp = hlp.hessp if self._need_hessp else None
        r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
                         hessp=hessp, options=self._options, bounds=bounds)
        # scipy does not promise that its last callback was evaluated at
        # r.x, so make sure the returned energy belongs to r.x (cheap no-op
        # if it already does).
        hlp._update(r.x)
        if not r.success:
            logger.error("Problem in Scipy minimization: {}".format(r.message))
            return hlp._energy, IterationController.ERROR
        return hlp._energy, IterationController.CONVERGED
Martin Reinecke's avatar
Martin Reinecke committed
141
142


143
def NewtonCG(xtol, maxiter, disp=False):
    """Returns a ScipyMinimizer object carrying out the Newton-CG algorithm.

    Parameters
    ----------
    xtol : float
        Forwarded unchanged as scipy's ``xtol`` option.
    maxiter : int
        Forwarded unchanged as scipy's ``maxiter`` option.
    disp : bool, optional
        Forwarded unchanged as scipy's ``disp`` option. Default: False.

    Notes
    -----
    The minimizer is created with ``need_hessp=True``, so scipy receives a
    Hessian-vector product built from the energy's metric, and without
    bounds.

    See Also
    --------
    ScipyMinimizer
    """
    scipy_options = {"xtol": xtol, "maxiter": maxiter, "disp": disp}
    return ScipyMinimizer("Newton-CG", options=scipy_options,
                          need_hessp=True, bounds=None)
Martin Reinecke's avatar
Martin Reinecke committed
152
153


154
def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None):
    """Returns a ScipyMinimizer object carrying out the L-BFGS-B algorithm.

    Parameters
    ----------
    ftol, gtol, maxiter : number
        Forwarded unchanged as the scipy options of the same names.
    maxcor : int, optional
        Forwarded unchanged as scipy's ``maxcor`` option. Default: 10.
    disp : bool, optional
        Forwarded unchanged as scipy's ``disp`` option. Default: False.
    bounds : pair (lo, hi) or None, optional
        Passed to ScipyMinimizer, which applies it to every entry of the
        flattened position. Default: None (unbounded).

    See Also
    --------
    ScipyMinimizer
    """
    scipy_options = {
        "ftol": ftol,
        "gtol": gtol,
        "maxiter": maxiter,
        "maxcor": maxcor,
        "disp": disp,
    }
    return ScipyMinimizer("L-BFGS-B", options=scipy_options,
                          need_hessp=False, bounds=bounds)
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183


class ScipyCG(Minimizer):
    """Minimize a QuadraticEnergy by running ``scipy.sparse.linalg.cg``.

    Parameters
    ----------
    tol : float
        Relative tolerance forwarded to scipy's ``cg``.
    maxiter : int or None
        Maximum number of CG iterations forwarded to scipy's ``cg``.
    """

    def __init__(self, tol, maxiter):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._tol = tol
        self._maxiter = maxiter

    def __call__(self, energy, preconditioner=None):
        """Solve ``energy._A x = energy._b`` starting at ``energy.position``.

        Parameters
        ----------
        energy : QuadraticEnergy
            The energy to minimize.
        preconditioner : operator or None
            Optional operator applied as scipy's preconditioner ``M``.

        Returns
        -------
        tuple
            The energy at the CG solution and a status. The status is
            ``IterationController.ERROR`` only if scipy reports illegal input
            or breakdown (negative ``info``). Not reaching the tolerance
            within ``maxiter`` is still reported as CONVERGED, but a warning
            is logged.

        Raises
        ------
        ValueError
            If *energy* is not a QuadraticEnergy.
        """
        from scipy.sparse.linalg import LinearOperator as scipy_linop, cg
        from .quadratic_energy import QuadraticEnergy
        if not isinstance(energy, QuadraticEnergy):
            raise ValueError("need a quadratic energy for CG")

        class mymatvec(object):
            # Wraps a NIFTy operator so it acts on flat numpy arrays.
            def __init__(self, op):
                self._op = op

            def __call__(self, inp):
                return _toArray(self._op(_toField(inp, energy.position)))

        op = energy._A
        b = _toArray(energy._b)
        sx = _toArray(energy.position)
        # Size the scipy operators from the flattened vectors they act on.
        # This is correct even when complex MultiField entries occupy two
        # slots each, which op.domain.size / op.target.size would not reflect.
        n = b.size
        sci_op = scipy_linop(shape=(n, n), matvec=mymatvec(op))
        prec_op = None
        if preconditioner is not None:
            prec_op = scipy_linop(shape=(n, n),
                                  matvec=mymatvec(preconditioner))
        cg_kwargs = dict(x0=sx, M=prec_op, maxiter=self._maxiter)
        try:
            # scipy renamed cg's relative tolerance from 'tol' to 'rtol'
            # (1.12) and later dropped 'tol'; try the new name first.
            res, stat = cg(sci_op, b, rtol=self._tol, **cg_kwargs)
        except TypeError as e:
            if "rtol" not in str(e):
                raise
            # older scipy: only 'tol' is understood
            res, stat = cg(sci_op, b, tol=self._tol, **cg_kwargs)
        if stat > 0:
            logger.warning("ScipyCG: tolerance not reached after {} "
                           "iterations".format(stat))
        stat = (IterationController.CONVERGED if stat >= 0 else
                IterationController.ERROR)
        return energy.at(_toField(res, energy.position)), stat