starblade_energy.py 2.86 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2
from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4.library import WienerFilterCurvature
Jakob Knollmueller's avatar
Jakob Knollmueller committed
3
from nifty4.library.nonlinearities import PositiveTanh
Jakob Knollmueller's avatar
Jakob Knollmueller committed
4 5


Jakob Knollmueller's avatar
Jakob Knollmueller committed
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
class StarbladeEnergy(Energy):
    """The Energy for the starblade problem.

    It implements the Information Hamiltonian of the separation of the data
    ``d`` into a diffuse component and a point-source component. With
    ``a = PositiveTanh(position)``, the data are split pixel-wise as
    ``d = exp(s) + exp(u)`` where ``exp(u) = a * d`` (point sources) and
    ``exp(s) = (1 - a) * d`` (diffuse emission).

    Parameters
    ----------
    position : Field
        The current position of the separation. Its values are clipped to
        ``[-_POSITION_BOUND, _POSITION_BOUND]`` before use.
    parameters : Dictionary
        Dictionary containing all relevant quantities for the inference,
        data : Field
            The image data.
        alpha : Field
            Slope parameter of the point-source prior
        q : Field
            Cutoff parameter of the point-source prior
        correlation : Field
            A field in the Fourier space which encodes the diagonal of the prior
            correlation structure of the diffuse component
        FFT : FFTOperator
            An operator performing the Fourier transform
        inverter : ConjugateGradient
            the minimization strategy to use for operator inversion
        var_x : float, optional
            Variance of the Gaussian prior on the position itself.
            Defaults to 9.
    """

    # Bound applied to every position value before it is used. Presumably
    # this keeps a = PositiveTanh(x) and 1 - a away from zero, so that
    # log(d * a), log(d * (1 - a)), a_p / a and a_p / (1 - a) stay finite.
    _POSITION_BOUND = 9

    def __init__(self, position, parameters):
        bound = self._POSITION_BOUND
        clipped = position.val.clip(-bound, bound)
        position = Field(position.domain, val=clipped)
        super(StarbladeEnergy, self).__init__(position=position)

        self.parameters = parameters
        self.inverter = parameters['inverter']
        self.d = parameters['data']
        self.FFT = parameters['FFT']
        self.correlation = parameters['correlation']
        self.alpha = parameters['alpha']
        self.q = parameters['q']
        # Prior variance of the position; optional key, default 9.
        self.var_x = parameters.get('var_x', 9.)

        # Diffuse prior covariance, built from the Fourier-space correlation.
        self.S = self.FFT * self.correlation * self.FFT.adjoint

        pos_tanh = PositiveTanh()
        self.a = pos_tanh(self.position)
        self.a_p = pos_tanh.derivative(self.position)

        # Point-source log-flux u and its derivative w.r.t. the position.
        self.u = log(self.d * self.a)
        self.u_p = self.a_p / self.a
        # Diffuse log-flux s and its derivative w.r.t. the position.
        one_m_a = 1 - self.a
        self.s = log(self.d * one_m_a)
        self.s_p = -self.a_p / one_m_a

        # exp(-u) is needed by value, gradient and curvature: compute it once.
        self._exp_m_u = exp(-self.u)
        # S^{-1} s is needed by value and gradient; computed lazily on first
        # use by _S_inv_s() and then reused.
        self._S_inv_s_cache = None

    def _S_inv_s(self):
        """Return ``S.inverse(s)``, evaluating it at most once per instance."""
        if self._S_inv_s_cache is None:
            self._S_inv_s_cache = self.S.inverse(self.s)
        return self._S_inv_s_cache

    def at(self, position):
        """Return a new StarbladeEnergy at `position` with the same parameters."""
        return self.__class__(position, parameters=self.parameters)

    @property
    def value(self):
        """Value of the Hamiltonian at the current position.

        It is the sum of the following terms:

        * the diffuse prior term ``0.5 * s . S^{-1} s``;
        * the point-source prior term ``(alpha - 1) . u + q . exp(-u)``;
        * ``s.integrate()``;
        * the Gaussian position prior ``0.5 * x . x / var_x``.
        """
        diffuse = 0.5 * self.s.vdot(self._S_inv_s())
        point = (self.alpha - 1).vdot(self.u) + self.q.vdot(self._exp_m_u)
        # NOTE(review): integrate() may weight by pixel volume, whereas the
        # matching gradient term (s_p) is unweighted; the two agree only for
        # unit pixel volumes -- verify against the intended model.
        det = self.s.integrate()
        det += 0.5 / self.var_x * self.position.vdot(self.position)
        return diffuse + point + det

    @property
    def gradient(self):
        """Gradient of `value` w.r.t. the position (chain rule via u_p, s_p)."""
        diffuse = self._S_inv_s() * self.s_p
        point = (self.alpha - 1) * self.u_p - self.q * self._exp_m_u * self.u_p
        det = self.position / self.var_x
        det += self.s_p
        return diffuse + point + det

    @property
    def curvature(self):
        """Approximate curvature of the Hamiltonian as a WienerFilterCurvature.

        The curvature is assembled from three pieces:

        * the response ``R = FFT.inverse * s_p``;
        * ``N = correlation``, the diffuse prior correlation;
        * a diagonal ``S`` whose inverse is
          ``q * exp(-u) * u_p**2 + 1 / var_x``.

        Terms involving second derivatives of ``u`` and ``s`` w.r.t. the
        position are not included.
        """
        point = self.q * self._exp_m_u * self.u_p ** 2
        R = self.FFT.inverse * self.s_p
        N = self.correlation
        S = DiagonalOperator(1 / (point + 1 / self.var_x))
        return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)