inversion_enabler.py 3.33 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2018 Max-Planck-Society
15 16 17 18
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import absolute_import, division, print_function
20

21
import numpy as np
22 23

from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
24
from ..logger import logger
25 26 27
from ..minimization.conjugate_gradient import ConjugateGradient
from ..minimization.iteration_controller import IterationController
from ..minimization.quadratic_energy import QuadraticEnergy
28
from ..sugar import full
29
from .endomorphic_operator import EndomorphicOperator
30 31


32
class InversionEnabler(EndomorphicOperator):
    """Class which augments the capability of another operator object via
    numerical inversion.

    Parameters
    ----------
    op : :class:`EndomorphicOperator`
        The operator to be enhanced.
        The InversionEnabler object will support the same operation modes as
        `op`, and additionally the inverse set. The newly-added modes will
        be computed by iterative inversion.
    iteration_controller : :class:`IterationController`
        The iteration controller to use for the iterative numerical inversion
        done by a :class:`ConjugateGradient` object.
    approximation : :class:`LinearOperator`, optional
        If not None, this operator should be an approximation to `op` which
        supports the operation modes that `op` doesn't have. It is used as a
        preconditioner during the iterative inversion, to accelerate
        convergence.

    Notes
    -----
    A mode requested by the caller that `op` supports is forwarded to `op`
    unchanged. Any other mode is computed by a conjugate-gradient solve,
    started from zero. If the solver does not report convergence, a warning
    is logged and the last iterate is returned anyway.
    """

    def __init__(self, op, iteration_controller, approximation=None):
        super(InversionEnabler, self).__init__()
        self._op = op
        self._ic = iteration_controller
        self._approximation = approximation

    @property
    def domain(self):
        """The domain of the wrapped operator `op`."""
        return self._op.domain

    @property
    def capability(self):
        """Capability bits of `op`, extended by the matching inverse modes.

        The extension is looked up in the `_addInverse` table, which is not
        defined in this class (presumably inherited from the operator base
        class).
        """
        return self._addInverse[self._op.capability]

    def apply(self, x, mode):
        """Apply this operator to `x` in the requested `mode`.

        Parameters
        ----------
        x : object exposing a ``domain`` attribute (presumably a Field)
            The input to which the operator is applied.
        mode : int
            The operation mode; validated by `_check_mode`.

        Returns
        -------
        The result of `op.apply(x, mode)` if `op` supports `mode` natively.
        Otherwise, the final position of a conjugate-gradient minimization.
        This position is returned even when the iteration controller does
        not report convergence; in that case only a warning is logged.
        """
        self._check_mode(mode)
        if self._op.capability & mode:
            # The wrapped operator handles this mode itself: no inversion.
            return self._op.apply(x, mode)

        # The iterative solver starts from an all-zero guess on x's domain.
        x0 = full(x.domain, 0.)
        # Find the mode inverse to the requested one (INVERSE_BIT row of the
        # mode table). Then obtain `op` acting in that mode; this is the
        # operator whose inverse is being computed. Presumably `op` supports
        # that mode natively.
        invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
        invop = self._op._flip_modes(self._ilog[invmode])
        # Bring the approximation into the *requested* mode, so that it acts
        # as an approximate inverse of `invop`, i.e. as a preconditioner.
        prec = self._approximation
        if prec is not None:
            prec = prec._flip_modes(self._ilog[mode])
        # NOTE(review): conjugate gradient generally requires `invop` to be
        # self-adjoint and positive definite; confirm callers guarantee this.
        energy = QuadraticEnergy(x0, invop, x)
        inverter = ConjugateGradient(self._ic)
        r, stat = inverter(energy, preconditioner=prec)
        if stat != IterationController.CONVERGED:
            # Best effort: report the problem but still return the iterate.
            logger.warning("Error detected during operator inversion")
        return r.position

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Draw a sample by delegating unchanged to `op.draw_sample`.

        Parameters
        ----------
        from_inverse : bool
            Passed through to `op.draw_sample`.
        dtype : numpy dtype
            Passed through to `op.draw_sample`.
        """
        return self._op.draw_sample(from_inverse, dtype)