# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

18
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19

Martin Reinecke's avatar
Martin Reinecke committed
20
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..logger import logger
Philipp Arras's avatar
Philipp Arras committed
22
23
from ..multi_field import MultiField
from ..utilities import iscomplextype
Martin Reinecke's avatar
fix    
Martin Reinecke committed
24
from .iteration_controllers import IterationController
25
from .minimizer import Minimizer
26
27


28
def _multiToArray(fld):
    """Pack every component of `fld` into one flat float64 array.

    Parameters
    ----------
    fld : object providing `values()`
        Each value must expose `size`, `dtype` and `val`
        (presumably a MultiField; only these attributes are used here).

    Returns
    -------
    numpy.ndarray
        Freshly allocated 1-D float64 array holding the components in
        the order delivered by `fld.values()`. Components with a complex
        dtype are reinterpreted as their real storage and therefore
        occupy twice as many entries as they have elements.
    """
    def _n_reals(component):
        # A complex entry needs two real slots (real and imaginary part).
        factor = 2 if iscomplextype(component.dtype) else 1
        return factor*component.size

    components = tuple(fld.values())
    flat = np.empty(sum(_n_reals(c) for c in components), dtype=np.float64)

    start = 0
    for comp in components:
        stop = start + _n_reals(comp)
        chunk = comp.val.reshape(-1)
        if iscomplextype(comp.dtype):
            # Reinterpret the complex buffer as plain reals, no conversion.
            chunk = chunk.view(chunk.real.dtype)
        flat[start:stop] = chunk
        start = stop
    return flat


Martin Reinecke's avatar
Martin Reinecke committed
43
def _toArray(fld):
    """Return the content of `fld` as a one-dimensional array.

    A `Field` is flattened via `reshape(-1)` on its `val` attribute;
    anything else is handed to `_multiToArray`.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    return fld.val.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
47
48


Martin Reinecke's avatar
Martin Reinecke committed
49
def _toArray_rw(fld):
    """Return the content of `fld` as a one-dimensional array.

    Same as `_toArray`, except that for a `Field` the data is obtained
    through `val_rw()` (presumably a writeable copy -- confirm against
    `Field`). Everything else goes through `_multiToArray`, which always
    allocates a new array.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    return fld.val_rw().reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
53
54


55
56
def _toField(arr, template):
    """Rebuild a field from the flat array `arr`, shaped like `template`.

    Parameters
    ----------
    arr : numpy.ndarray
        One-dimensional data, laid out as produced by `_toArray`.
    template : Field or object providing `values()` and `domain`
        Supplies the domains, shapes and dtypes of the result.

    Returns
    -------
    Field or MultiField
        A `Field` if `template` is a `Field`, a `MultiField` otherwise.
        The data is always copied out of `arr`.
    """
    if isinstance(template, Field):
        data = arr.reshape(template.shape).copy()
        return Field(template.target, data)

    pieces = []
    start = 0
    for comp in template.values():
        is_complex = iscomplextype(comp.dtype)
        # Complex components take two real entries per element.
        stop = start + (2*comp.size if is_complex else comp.size)
        chunk = arr[start:stop].copy()
        if is_complex:
            chunk = chunk.view(np.complex128)
        pieces.append(Field(comp.domain, chunk.reshape(comp.shape)))
        start = stop
    return MultiField(template.domain, tuple(pieces))
Martin Reinecke's avatar
Martin Reinecke committed
68
69


70
71
72
class _MinHelper(object):
    """Adapter exposing an energy through array-based callbacks.

    The methods `fun`, `jac` and `hessp` take flat arrays and return a
    value or flat arrays, converting to and from fields with `_toField`
    and `_toArray_rw`. The most recently evaluated energy is kept in
    `_energy`.

    Parameters
    ----------
    energy : object
        Must provide `position`, `value`, `gradient`, `at()` and
        `apply_metric()`, as used below.
    """

    def __init__(self, energy):
        self._energy = energy
        self._domain = energy.position.target

    def _update(self, x):
        """Move the stored energy to the position described by `x`.

        The energy is only re-created when at least one entry of the new
        position differs from the current one.
        """
        current = self._energy.position
        candidate = _toField(x, current)
        moved = (candidate != current).s_any()
        if moved:
            self._energy = self._energy.at(candidate)

    def fun(self, x):
        """Return the energy value at the flat position `x`."""
        self._update(x)
        return self._energy.value

    def jac(self, x):
        """Return the gradient at the flat position `x` as a flat array."""
        self._update(x)
        gradient = self._energy.gradient
        return _toArray_rw(gradient)

    def hessp(self, x, p):
        """Return the metric at `x` applied to `p`, as a flat array."""
        self._update(x)
        direction = _toField(p, self._energy.position)
        applied = self._energy.apply_metric(direction)
        return _toArray_rw(applied)
Martin Reinecke's avatar
Martin Reinecke committed
92
93


Martin Reinecke's avatar
Martin Reinecke committed
94
class _ScipyMinimizer(Minimizer):
    """Scipy-based minimizer

    Parameters
    ----------
    method     : str
        The selected Scipy minimization method.
    options    : dictionary
        A set of custom options for the selected minimizer.
    need_hessp : bool
        If true, the Hessian-vector product callback is handed to Scipy;
        otherwise `hessp=None` is passed.
    bounds     : None or sequence of length 2
        `None` for an unbounded problem, otherwise a `(lo, hi)` pair that
        is applied to every entry of the flattened position vector.
    """

    def __init__(self, method, options, need_hessp, bounds):
        self._method = method
        self._options = options
        self._need_hessp = need_hessp
        self._bounds = bounds

    def __call__(self, energy):
        """Minimize `energy` with `scipy.optimize.minimize`.

        Parameters
        ----------
        energy : object
            Starting point of the minimization; it is wrapped in a
            `_MinHelper`.

        Returns
        -------
        tuple
            The last energy held by the helper and
            `IterationController.CONVERGED` if Scipy reports success,
            `IterationController.ERROR` otherwise.

        Raises
        ------
        ValueError
            If bounds were given but do not have exactly two entries.
        """
        import scipy.optimize as opt
        hlp = _MinHelper(energy)
        energy = None  # drop handle, since we don't need it any more

        # Flatten first: the number of optimization variables is the
        # length of this vector, which exceeds `position.size` whenever
        # complex components are unpacked into real/imaginary entries.
        x = _toArray_rw(hlp._energy.position)

        bounds = None
        if self._bounds is not None:
            if len(self._bounds) != 2:
                raise ValueError("unrecognized bounds")
            lo = self._bounds[0]
            hi = self._bounds[1]
            # Scipy expects one (min, max) pair per entry of x.
            bounds = [(lo, hi)]*x.size

        hessp = hlp.hessp if self._need_hessp else None
        r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
                         hessp=hessp, options=self._options, bounds=bounds)
        if not r.success:
            logger.error("Problem in Scipy minimization: {}".format(r.message))
            return hlp._energy, IterationController.ERROR
        return hlp._energy, IterationController.CONVERGED
Martin Reinecke's avatar
Martin Reinecke committed
132
133


134
def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None):
    """Returns a _ScipyMinimizer object carrying out the L-BFGS-B algorithm.

    The arguments `ftol`, `gtol`, `maxiter`, `maxcor` and `disp` are
    forwarded unchanged as the Scipy options of the same name; `bounds`
    is handed to the `_ScipyMinimizer` constructor. No Hessian-vector
    product is requested.

    See Also
    --------
    _ScipyMinimizer
    """
    options = dict(ftol=ftol, gtol=gtol, maxiter=maxiter,
                   maxcor=maxcor, disp=disp)
    minimizer = _ScipyMinimizer("L-BFGS-B", options, False, bounds)
    return minimizer
144
145


Martin Reinecke's avatar
Martin Reinecke committed
146
147
class _ScipyCG(Minimizer):
    """Minimizer running the conjugate gradient algorithm as implemented
    by SciPy.

    This class is only intended for double-checking NIFTy's own conjugate
    gradient implementation and should not be used otherwise.

    Parameters
    ----------
    tol : float
        Passed to SciPy's `cg` as `tol`.
    maxiter : int
        Passed to SciPy's `cg` as `maxiter`.
    """

    def __init__(self, tol, maxiter):
        self._tol = tol
        self._maxiter = maxiter

    def __call__(self, energy, preconditioner=None):
        """Solve the linear system behind a quadratic energy with SciPy CG.

        Parameters
        ----------
        energy : QuadraticEnergy
            Its `_A`, `_b` and `position` define operator, right-hand
            side and starting vector.
        preconditioner : callable, optional
            If given, it is wrapped the same way as the operator and
            passed to `cg` as `M`.

        Returns
        -------
        tuple
            The energy at the solution returned by `cg`, and
            `IterationController.CONVERGED` for any non-negative SciPy
            status code, `IterationController.ERROR` for a negative one.

        Raises
        ------
        ValueError
            If `energy` is not a `QuadraticEnergy`.
        """
        from scipy.sparse.linalg import LinearOperator as scipy_linop, cg
        from .quadratic_energy import QuadraticEnergy
        if not isinstance(energy, QuadraticEnergy):
            raise ValueError("need a quadratic energy for CG")

        def _as_matvec(nifty_op):
            # Wrap an operator acting on fields so that it acts on flat
            # arrays shaped like the energy's position.
            def _apply(vec):
                as_field = _toField(vec, energy.position)
                return _toArray(nifty_op(as_field))
            return _apply

        op = energy._A
        rhs = _toArray(energy._b)
        start = _toArray(energy.position)
        dims = (op.domain.size, op.target.size)

        sci_op = scipy_linop(shape=dims, matvec=_as_matvec(op))
        if preconditioner is None:
            prec_op = None
        else:
            prec_op = scipy_linop(shape=dims,
                                  matvec=_as_matvec(preconditioner))

        # NOTE(review): `tol=` together with `atol='legacy'` is presumably
        # only accepted by older SciPy releases -- confirm the supported
        # SciPy version range.
        res, stat = cg(sci_op, rhs, x0=start, tol=self._tol, M=prec_op,
                       maxiter=self._maxiter, atol='legacy')
        if stat >= 0:
            stat = IterationController.CONVERGED
        else:
            stat = IterationController.ERROR
        return energy.at(_toField(res, energy.position)), stat