# inversion_enabler.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from ..minimization.quadratic_energy import QuadraticEnergy
20 21
from ..minimization.iteration_controller import IterationController
from ..field import Field, dobj
Martin Reinecke's avatar
Martin Reinecke committed
22
from .linear_operator import LinearOperator
23 24


Martin Reinecke's avatar
Martin Reinecke committed
25 26
class InversionEnabler(LinearOperator):
    """Makes the inverse application modes of a linear operator available.

    Application modes that the wrapped operator supports itself are forwarded
    to it unchanged. Any other requested mode is computed numerically: the
    wrapped operator, applied in the corresponding inverse mode, serves as
    the operator `A` of a `QuadraticEnergy` with right-hand side `b = x`.
    That energy is then minimized by `inverter`.

    Parameters
    ----------
    op : LinearOperator
        The operator whose inverse application modes are to be enabled.
    inverter : callable
        Called as ``inverter(energy, preconditioner=preconditioner)`` with a
        `QuadraticEnergy`. It must return an ``(energy, status)`` pair. The
        status is compared against ``IterationController.CONVERGED``.
    preconditioner : optional
        Forwarded unchanged to `inverter`. Default: None.
    """

    def __init__(self, op, inverter, preconditioner=None):
        super(InversionEnabler, self).__init__()
        self._op = op
        self._inverter = inverter
        self._preconditioner = preconditioner

    @property
    def domain(self):
        """The domain of the wrapped operator."""
        return self._op.domain

    @property
    def target(self):
        """The target of the wrapped operator."""
        return self._op.target

    @property
    def capability(self):
        """The wrapped operator's capability, looked up in `_addInverse`."""
        # NOTE(review): `_addInverse` is inherited from LinearOperator and is
        # not visible here. Presumably it maps a capability to the same
        # capability plus the matching inverse modes -- confirm there.
        return self._addInverse[self._op.capability]

    def apply(self, x, mode):
        """Apply the operator to `x` in application mode `mode`.

        If the wrapped operator supports `mode` directly, its result is
        returned. Otherwise a `QuadraticEnergy` is built with:

        - ``A``: the wrapped operator applied in ``self._inverseMode[mode]``;
        - ``b``: `x`;
        - starting position: a zero field in ``self._tgt(mode)``, using
          `x`'s dtype.

        The energy is minimized with `inverter`, and the final position is
        returned.

        If `inverter` reports a status other than
        ``IterationController.CONVERGED``, a message is printed via
        `dobj.mprint`. The last position is still returned, so the result
        may be inaccurate.
        """
        self._check_mode(mode)
        if self._op.capability & mode:
            # The wrapped operator can handle this mode itself.
            return self._op.apply(x, mode)

        # Loop-invariant: the mode in which the wrapped operator acts as `A`.
        inverse_mode = self._inverseMode[mode]

        def forward(field):
            # Distinct parameter name so the right-hand side `x` of the
            # enclosing scope is not shadowed.
            return self._op.apply(field, inverse_mode)

        x0 = Field.zeros(self._tgt(mode), dtype=x.dtype)
        # NOTE(review): presumably QuadraticEnergy represents
        # 0.5 * r^T A r - b^T r, whose minimizer solves A r = b when A is
        # symmetric positive definite. That property is not checked here --
        # confirm in QuadraticEnergy and with callers.
        energy = QuadraticEnergy(A=forward, b=x, position=x0)
        r, stat = self._inverter(energy, preconditioner=self._preconditioner)
        if stat != IterationController.CONVERGED:
            # Best effort: report the failure but still return the last
            # iterate instead of raising.
            dobj.mprint("Error detected during operator inversion")
        return r.position