# propagator_operator.py (3.06 KB)
# (header residue from a repository "Newer / Older" blame view)
1
2
# -*- coding: utf-8 -*-
from nifty.minimization import ConjugateGradient
3
4
5
from nifty.nifty_utilities import get_default_codomain
from nifty.operators import EndomorphicOperator,\
                            FFTOperator
6

7
8
9
import logging
# Module logger, named after the operator class defined below.
logger = logging.getLogger(name='NIFTy.PropagatorOperator')

10

11
class PropagatorOperator(EndomorphicOperator):
    """Propagator operator ``D = (S^{-1} + M)^{-1}``.

    The operator is never built explicitly. Its inverse ``S^{-1} + M`` is
    cheap to apply (see `_inverse_times`), so `_times` obtains ``D x`` by
    solving ``(S^{-1} + M) y = x`` for ``y`` with an iterative inverter
    (by default a `ConjugateGradient` preconditioned with ``S``).

    The likelihood term ``M`` is either given directly, or assembled as
    ``R^dagger N^{-1} R`` (or as ``N^{-1}`` when no response is given).
    FFTs are inserted to move between the operator's domain and the
    domains of ``S`` and ``N``.
    """

    # ---Overwritten properties and methods---

    def __init__(self, S=None, M=None, R=None, N=None, inverter=None,
                 preconditioner=None):
        """
            Sets up the domain, the prior and likelihood terms, and the
            inverter used to apply the operator.

            Parameters
            ----------
            S : operator
                Covariance of the signal prior. Required despite the
                ``None`` default.
            M : operator, optional
                Likelihood contribution. If given, it is used as-is and `R`
                and `N` are ignored; ``M.domain`` becomes the domain.
            R : operator, optional
                Response operator translating signal to (noiseless) data.
                Only used when `M` is not given; ``R.domain`` becomes the
                domain.
            N : operator, optional
                Covariance of the noise. Required when `M` is not given.
                Without `R`, the domain is the default codomain of
                ``N.domain[0]``.
            inverter : callable, optional
                Called as ``inverter(A=..., b=...)`` and expected to return
                a ``(solution, convergence)`` pair. Defaults to a
                `ConjugateGradient` built with `preconditioner`.
            preconditioner : callable, optional
                Stored as ``self.preconditioner`` and handed to the default
                inverter (not to a user-supplied one). Defaults to applying
                `S` in the operator's domain.

            Raises
            ------
            ValueError
                If neither `M` nor `N` is given, or if `S` is not given.
        """
        # infer domain and likelihood term
        if M is not None:
            self._domain = M.domain
            self._likelihood_times = M.times

        elif N is None:
            raise ValueError("Either M or N must be given!")

        elif R is not None:
            self._domain = R.domain
            # fft_RN is applied to R's *output*, so it has to start from
            # R.target (R.domain is only correct for endomorphic R).
            fft_RN = FFTOperator(R.target, target=N.domain)
            self._likelihood_times = \
                lambda z: R.adjoint_times(
                            fft_RN.inverse_times(N.inverse_times(
                                fft_RN(R.times(z)))))
        else:
            self._domain = (get_default_codomain(N.domain[0]),)
            fft_RN = FFTOperator(self._domain, target=N.domain)
            self._likelihood_times = \
                lambda z: fft_RN.inverse_times(N.inverse_times(
                                fft_RN(z)))

        if S is None:
            raise ValueError("S must be given!")

        # Same convention as fft_RN: `times` maps the operator's domain to
        # S.domain, S acts in its own domain, `inverse_times` maps back.
        fft_S = FFTOperator(self._domain, target=S.domain)
        self._S_times = lambda z: fft_S.inverse_times(S(fft_S(z)))
        self._S_inverse_times = lambda z: fft_S.inverse_times(
                                            S.inverse_times(fft_S(z)))

        if preconditioner is None:
            preconditioner = self._S_times

        self.preconditioner = preconditioner

        if inverter is not None:
            self.inverter = inverter
        else:
            self.inverter = ConjugateGradient(
                                preconditioner=self.preconditioner)

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        return self._domain

    @property
    def field_type(self):
        return ()

    @property
    def implemented(self):
        return True

    @property
    def symmetric(self):
        return True

    @property
    def unitary(self):
        return False

    # ---Added properties and methods---

    def _times(self, x, spaces, types):
        """Apply the propagator by solving ``(S^{-1} + M) y = x`` for ``y``.

        `spaces` and `types` are accepted for the operator interface but
        are not used here. The inverter's convergence information is
        discarded.
        """
        result, _ = self.inverter(A=self._inverse_times, b=x)
        return result

    def _inverse_times(self, x, spaces=None, types=None):
        """Apply the inverse propagator, ``S^{-1} x + M x``.

        `spaces` and `types` default to ``None`` so the inverter can call
        this with `x` alone; they are not used.
        """
        result = self._S_inverse_times(x)
        result += self._likelihood_times(x)
        return result

    def _inverse_multiply(self, x, **kwargs):
        """Backward-compatible alias of `_inverse_times`; kwargs ignored."""
        return self._inverse_times(x)