# log_normal_wiener_filter_curvature.py
from nifty.operators import EndomorphicOperator,\
                            InvertibleOperatorMixin
3
from nifty.energies.memoization import memo
4 5
from nifty.basic_arithmetics import exp

6 7 8

class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
                                     EndomorphicOperator):
    """The curvature of the LogNormalWienerFilterEnergy.

    This operator implements the second derivative of the
    LogNormalWienerFilterEnergy used in some minimization algorithms or for
    error estimates of the posterior maps. It is the inverse of the propagator
    operator.

    Applied to a field ``x`` the operator returns the sum of three terms::

        S^{-1} x
        + [exp(p) * R^adjoint N^{-1} (R exp(p) - d)] * x
        + exp(p) * R^adjoint N^{-1} R (exp(p) * x)

    where ``p`` is `position` and ``*`` is the product provided by the field
    type (presumably pointwise -- confirm against the field implementation).

    Parameters
    ----------
    R : LinearOperator
        The response operator of the Wiener filter measurement.
    N : EndomorphicOperator
        The noise covariance.
    S : DiagonalOperator
        The prior signal covariance. Its domain is used as the domain of
        this operator.
    d
        The quantity subtracted from ``R(exp(position))`` to form the
        residual (presumably the measured data -- confirm with the caller).
    position
        The point at which the curvature is evaluated; it enters only
        through ``exp(position)``.
    inverter : optional
        Forwarded unchanged to the parent constructor.
    preconditioner : optional
        Forwarded to the parent constructor. If None, ``S.times`` is used.
    **kwargs
        Any further keyword arguments are forwarded to the parent
        constructor.

    """

    def __init__(self, R, N, S, d, position, inverter=None,
                 preconditioner=None, **kwargs):
        # Created before anything else so it exists when the memoized
        # properties below are first read (presumably this is the storage
        # the `memo` decorator uses -- confirm in nifty.energies.memoization).
        self._cache = {}

        self.R = R
        self.N = N
        self.S = S
        self.d = d
        self.position = position

        # Fall back to applying the prior covariance as preconditioner.
        if preconditioner is None:
            preconditioner = self.S.times
        self._domain = self.S.domain

        super(LogNormalWienerFilterCurvature, self).__init__(
                                                 inverter=inverter,
                                                 preconditioner=preconditioner,
                                                 **kwargs)

    @property
    def domain(self):
        """The domain of this operator, taken from ``S.domain``."""
        return self._domain

    @property
    def self_adjoint(self):
        """Always True for this operator."""
        return True

    @property
    def unitary(self):
        """Always False for this operator."""
        return False

    # ---Added properties and methods---

    def _times(self, x, spaces):
        """Apply the curvature to `x`.

        Note that `spaces` is accepted for interface compatibility but is
        not used by this implementation.
        """
        # Prior contribution: S^{-1} x
        part1 = self.S.inverse_times(x)
        # Residual-dependent contribution:
        # [exp(p) * R^adjoint N^{-1} (R exp(p) - d)] * x
        part2 = self._exppRNRexppd * x
        # Likelihood contribution:
        # exp(p) * R^adjoint N^{-1} R (exp(p) * x)
        part3 = self._expp * self.R.adjoint_times(
                                self.N.inverse_times(self.R(self._expp * x)))
        return part1 + part2 + part3

    @property
    @memo
    def _expp(self):
        """``exp(position)``."""
        return exp(self.position)

    @property
    @memo
    def _Rexppd(self):
        """The residual ``R(exp(position)) - d``."""
        return self.R(self._expp) - self.d

    @property
    @memo
    def _NRexppd(self):
        """The residual with the inverse of ``N`` applied to it."""
        return self.N.inverse_times(self._Rexppd)

    @property
    @memo
    def _exppRNRexppd(self):
        """``exp(position) * R.adjoint_times(N^{-1} (R exp(position) - d))``."""
        return self._expp * self.R.adjoint_times(self._NRexppd)