scipy_minimizer.py 6.36 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2 3 4 5 6 7 8 9 10 11 12 13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

18
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19

20
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..logger import logger
Philipp Arras's avatar
Philipp Arras committed
23 24
from ..multi_field import MultiField
from ..utilities import iscomplextype
Martin Reinecke's avatar
fix  
Martin Reinecke committed
25
from .iteration_controllers import IterationController
26
from .minimizer import Minimizer
27 28


29
def _multiToArray(fld):
    """Pack every value of a MultiField into one flat float64 array.

    Values are concatenated in the iteration order of ``fld.values()``.
    Complex-valued entries are reinterpreted (via a numpy ``view``) as
    interleaved (real, imag) pairs, so they occupy twice their element
    count in the result.
    """
    chunks = []
    for val in fld.values():
        flat = val.local_data.reshape(-1)
        if iscomplextype(val.dtype):
            # Reinterpret complex storage as its real component type.
            flat = flat.view(flat.real.dtype)
        chunks.append(flat)
    packed = np.empty(sum(chunk.size for chunk in chunks), dtype=np.float64)
    start = 0
    for chunk in chunks:
        stop = start + chunk.size
        packed[start:stop] = chunk
        start = stop
    return packed


Martin Reinecke's avatar
Martin Reinecke committed
44
def _toArray(fld):
    """Return the contents of a Field or MultiField as a 1-D numpy array.

    For a plain Field this is ``local_data.reshape(-1)``, which numpy may
    return as a view sharing memory with the field; callers that need to
    write into the result should use ``_toArray_rw`` instead. Anything else
    is packed by ``_multiToArray``.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    return fld.local_data.reshape(-1)
Martin Reinecke's avatar
Martin Reinecke committed
48 49


Martin Reinecke's avatar
Martin Reinecke committed
50
def _toArray_rw(fld):
    """Return a writable, independently owned 1-D copy of a (Multi)Field.

    A plain Field's local data is flattened in C order into a fresh array.
    Other inputs go through ``_multiToArray``, which already allocates a
    new buffer.
    """
    if not isinstance(fld, Field):
        return _multiToArray(fld)
    # flatten() always copies (C order), unlike reshape(-1).
    return fld.local_data.flatten()
Martin Reinecke's avatar
Martin Reinecke committed
54 55


56 57 58
def _toField(arr, template):
    """Rebuild a Field/MultiField shaped like ``template`` from a flat array.

    This reverses the packing done by ``_toArray``/``_toArray_rw``. For a
    MultiField template, complex entries are read as interleaved (real,
    imag) pairs and viewed as ``np.complex128``. This presumes ``arr`` is
    float64, and the result is complex128 regardless of the template's
    dtype. All returned data are copies of the corresponding slices of
    ``arr``.
    """
    if isinstance(template, Field):
        return Field.from_local_data(template.domain,
                                     arr.reshape(template.shape).copy())
    fields = []
    start = 0
    for tmpl in template.values():
        is_cplx = iscomplextype(tmpl.dtype)
        stop = start + (2*tmpl.size if is_cplx else tmpl.size)
        chunk = arr[start:stop].copy()
        if is_cplx:
            chunk = chunk.view(np.complex128)
        fields.append(
            Field.from_local_data(tmpl.domain, chunk.reshape(tmpl.shape)))
        start = stop
    return MultiField(template.domain, tuple(fields))
Martin Reinecke's avatar
Martin Reinecke committed
70 71


72 73 74 75 76 77
class _MinHelper(object):
    """Adapter exposing an energy through scipy's flat-array callbacks.

    ``scipy.optimize`` works on 1-D arrays. Each callback converts the array
    back into a field shaped like the current position and moves the energy
    there if needed. It then converts the requested quantity back into an
    array.
    """

    def __init__(self, energy):
        self._energy = energy
        self._domain = energy.position.domain

    def _update(self, x):
        # Move the energy only when x differs from the current position;
        # repeated callbacks at the same point keep the existing object.
        current = self._energy.position
        candidate = _toField(x, current)
        if not (candidate != current).any():
            return
        self._energy = self._energy.at(candidate)

    def fun(self, x):
        """Energy value at ``x``."""
        self._update(x)
        return self._energy.value

    def jac(self, x):
        """Energy gradient at ``x`` as a writable flat array."""
        self._update(x)
        grad = self._energy.gradient
        return _toArray_rw(grad)

    def hessp(self, x, p):
        """Metric at ``x`` applied to direction ``p``, as a flat array."""
        self._update(x)
        direction = _toField(p, self._energy.position)
        applied = self._energy.apply_metric(direction)
        return _toArray_rw(applied)
Martin Reinecke's avatar
Martin Reinecke committed
94 95


Martin Reinecke's avatar
Martin Reinecke committed
96
class _ScipyMinimizer(Minimizer):
    """Scipy-based minimizer

    Parameters
    ----------
    method     : str
        The selected Scipy minimization method.
    options    : dictionary
        A set of custom options for the selected minimizer.
    need_hessp : bool
        If True, a Hessian-vector-product callback (the energy's metric
        applied to a vector) is passed to ``scipy.optimize.minimize``.
    bounds     : None or sequence of length 2
        If not None, a ``(lo, hi)`` pair that is applied to every entry of
        the flattened position array handed to scipy.

    Raises
    ------
    NotImplementedError
        If the data-distribution backend is not plain numpy.
    """

    def __init__(self, method, options, need_hessp, bounds):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._method = method
        self._options = options
        self._need_hessp = need_hessp
        self._bounds = bounds

    def __call__(self, energy):
        """Minimize ``energy``; return ``(final_energy, status)``.

        ``status`` is ``IterationController.CONVERGED`` when scipy reports
        success and ``IterationController.ERROR`` otherwise (the failure
        message is logged).
        """
        import scipy.optimize as opt
        hlp = _MinHelper(energy)
        energy = None  # drop handle, since we don't need it any more
        if self._bounds is not None and len(self._bounds) != 2:
            raise ValueError("unrecognized bounds")

        x = _toArray_rw(hlp._energy.position)
        bounds = None
        if self._bounds is not None:
            lo = self._bounds[0]
            hi = self._bounds[1]
            # scipy needs one (lo, hi) pair per entry of x. For MultiFields
            # with complex values, x holds two reals per complex value, so
            # x.size (not position.size) is the correct count.
            bounds = [(lo, hi)]*x.size

        hessp = hlp.hessp if self._need_hessp else None
        r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
                         hessp=hessp, options=self._options, bounds=bounds)
        if not r.success:
            logger.error("Problem in Scipy minimization: {}".format(r.message))
            return hlp._energy, IterationController.ERROR
        return hlp._energy, IterationController.CONVERGED
Martin Reinecke's avatar
Martin Reinecke committed
136 137


138
def L_BFGS_B(ftol, gtol, maxiter, maxcor=10, disp=False, bounds=None):
    """Returns a _ScipyMinimizer object carrying out the L-BFGS-B algorithm.

    Parameters
    ----------
    ftol : float
        Forwarded as scipy's L-BFGS-B ``ftol`` option.
    gtol : float
        Forwarded as scipy's L-BFGS-B ``gtol`` option.
    maxiter : int
        Forwarded as scipy's ``maxiter`` option.
    maxcor : int, optional
        Forwarded as scipy's L-BFGS-B ``maxcor`` option. Default: 10.
    disp : bool, optional
        Forwarded as scipy's ``disp`` option. Default: False.
    bounds : None or sequence of length 2, optional
        ``(lo, hi)`` applied to every entry of the flattened position.

    See Also
    --------
    _ScipyMinimizer
    """
    # L-BFGS-B only uses the gradient, so no Hessian-vector product is needed.
    options = dict(ftol=ftol, gtol=gtol, maxiter=maxiter, maxcor=maxcor,
                   disp=disp)
    return _ScipyMinimizer("L-BFGS-B", options, need_hessp=False,
                           bounds=bounds)
148 149


Martin Reinecke's avatar
Martin Reinecke committed
150 151
class _ScipyCG(Minimizer):
    """Conjugate gradient solver delegating to ``scipy.sparse.linalg.cg``.

    This class is only intended for double-checking NIFTy's own conjugate
    gradient implementation and should not be used otherwise.

    Parameters
    ----------
    tol : float
        Relative residual tolerance handed to scipy's ``cg``.
    maxiter : int or None
        Maximum number of iterations handed to scipy's ``cg``.

    Raises
    ------
    NotImplementedError
        If the data-distribution backend is not plain numpy.
    """
    def __init__(self, tol, maxiter):
        if not dobj.is_numpy():
            raise NotImplementedError
        self._tol = tol
        self._maxiter = maxiter

    def __call__(self, energy, preconditioner=None):
        """Solve the quadratic problem of ``energy``; return ``(energy, status)``.

        Raises ValueError if ``energy`` is not a QuadraticEnergy.
        """
        import inspect
        from scipy.sparse.linalg import LinearOperator as scipy_linop, cg
        from .quadratic_energy import QuadraticEnergy
        if not isinstance(energy, QuadraticEnergy):
            raise ValueError("need a quadratic energy for CG")

        def as_matvec(linop):
            """Wrap a NIFTy operator as a flat-array matvec callback."""
            def matvec(inp):
                return _toArray(linop(_toField(inp, energy.position)))
            return matvec

        op = energy._A
        b = _toArray(energy._b)
        sx = _toArray(energy.position)
        # scipy's shape is (rows, cols) == (target size, domain size); CG
        # needs a square operator, so both coincide in practice.
        shape = (op.target.size, op.domain.size)
        sci_op = scipy_linop(shape=shape, matvec=as_matvec(op))
        prec_op = None
        if preconditioner is not None:
            prec_op = scipy_linop(shape=shape,
                                  matvec=as_matvec(preconditioner))

        cg_kwargs = {"x0": sx, "M": prec_op, "maxiter": self._maxiter}
        if "rtol" in inspect.signature(cg).parameters:
            # SciPy >= 1.12 renamed `tol` to `rtol`; `tol` and
            # atol='legacy' were later removed. rtol=tol with atol=0
            # approximates the legacy ||r|| <= tol*||b|| stopping rule.
            cg_kwargs.update(rtol=self._tol, atol=0.)
        else:
            cg_kwargs.update(tol=self._tol, atol='legacy')
        res, stat = cg(sci_op, b, **cg_kwargs)

        # scipy: stat == 0 converged, stat > 0 iteration limit reached
        # without meeting the tolerance, stat < 0 illegal input/breakdown.
        if stat > 0:
            logger.warning(
                "Scipy CG: tolerance not reached after {} iterations"
                .format(stat))
        stat = (IterationController.CONVERGED if stat >= 0 else
                IterationController.ERROR)
        return energy.at(_toField(res, energy.position)), stat