invertible_operator_mixin.py 2.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19
20
21
22
23
24

from nifty.minimization import ConjugateGradient

from nifty.field import Field


class InvertibleOperatorMixin(object):
    """Mixin that evaluates operator applications by iterative inversion.

    Each of ``_times``, ``_adjoint_times``, ``_inverse_times`` and
    ``_adjoint_inverse_times`` is implemented by asking the configured
    inverter to solve ``A(result) = x``, where ``A`` is the complementary
    public method of the host operator:

    ==========================  ================================
    method                      ``A`` handed to the inverter
    ==========================  ================================
    ``_times``                  ``self.inverse_times``
    ``_adjoint_times``          ``self.adjoint_inverse_times``
    ``_inverse_times``          ``self.times``
    ``_adjoint_inverse_times``  ``self.adjoint_times``
    ==========================  ================================

    The public methods and the ``domain`` / ``target`` attributes are not
    defined here and must come from the class this mixin is combined with.

    NOTE(review): presumably a concrete subclass overrides one method of
    each pair (e.g. ``_inverse_times``) so the solver has a directly
    computable operator. If the public methods dispatch to these underscore
    methods and neither member of a pair is overridden, the two would call
    each other through the solver without terminating. Confirm against
    the operator base class.
    """

    def __init__(self, inverter=None, preconditioner=None, *args, **kwargs):
        """Configure the inverter and forward remaining constructor args.

        Parameters
        ----------
        inverter : callable, optional
            Called as ``inverter(A=..., b=..., x0=...)``. It must return a
            2-tuple whose first element is the solution. When omitted,
            ``ConjugateGradient(preconditioner=preconditioner)`` is used.
        preconditioner : optional
            Passed to the default ``ConjugateGradient``. It is not applied
            to an explicitly supplied ``inverter``.
        *args, **kwargs
            Forwarded to the next ``__init__`` in the MRO.
        """
        self.__preconditioner = preconditioner
        if inverter is not None:
            self.__inverter = inverter
        else:
            self.__inverter = ConjugateGradient(
                                        preconditioner=self.__preconditioner)
        # Cooperative multiple inheritance: hand the remaining arguments to
        # the co-base instead of silently dropping them, so its __init__
        # actually runs. Unexpected arguments now surface as a TypeError
        # from object.__init__ rather than vanishing.
        super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)

    def __solve(self, A, b, x0, x0_domain_attr):
        """Run the inverter on ``A(result) = b`` and return ``result``.

        Parameters
        ----------
        A : callable
            Operator handed to the inverter.
        b : object
            Right-hand side (the input of the calling method; presumably a
            Field). Its ``dtype`` is read only when ``x0`` is None.
        x0 : Field or None
            Starting point. If None, a zero Field is built on
            ``getattr(self, x0_domain_attr)`` with ``b``'s dtype.
        x0_domain_attr : str
            ``'domain'`` or ``'target'``: the space the solution lives in.
            The attribute is only read when a default ``x0`` is needed.
        """
        if x0 is None:
            x0 = Field(getattr(self, x0_domain_attr), val=0., dtype=b.dtype)
        # The second element of the inverter's return value is discarded.
        # NOTE(review): a failure to converge is therefore not reported.
        (result, _convergence) = self.__inverter(A=A, b=b, x0=x0)
        return result

    def _times(self, x, spaces, x0=None):
        """Apply the operator by solving ``inverse_times(result) = x``.

        ``spaces`` is accepted for interface compatibility but is unused.
        ``x0`` is the solver start; it defaults to zeros on ``self.target``.
        """
        return self.__solve(self.inverse_times, x, x0, 'target')

    def _adjoint_times(self, x, spaces, x0=None):
        """Apply the adjoint by solving ``adjoint_inverse_times(r) = x``.

        ``spaces`` is accepted for interface compatibility but is unused.
        ``x0`` is the solver start; it defaults to zeros on ``self.domain``.
        """
        return self.__solve(self.adjoint_inverse_times, x, x0, 'domain')

    def _inverse_times(self, x, spaces, x0=None):
        """Apply the inverse by solving ``times(result) = x``.

        ``spaces`` is accepted for interface compatibility but is unused.
        ``x0`` is the solver start; it defaults to zeros on ``self.domain``.
        """
        return self.__solve(self.times, x, x0, 'domain')

    def _adjoint_inverse_times(self, x, spaces, x0=None):
        """Apply the adjoint inverse by solving ``adjoint_times(r) = x``.

        ``spaces`` is accepted for interface compatibility but is unused.
        ``x0`` is the solver start; it defaults to zeros on ``self.target``.
        """
        return self.__solve(self.adjoint_times, x, x0, 'target')

    def _inverse_adjoint_times(self, x, spaces):
        """Not provided generically by this mixin; always raises."""
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")