starblade_energy.py 3.63 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik

Martin Reinecke's avatar
Martin Reinecke committed
19 20
from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4.library import WienerFilterCurvature
Jakob Knollmueller's avatar
Jakob Knollmueller committed
21
from nifty4.library.nonlinearities import PositiveTanh
Jakob Knollmueller's avatar
Jakob Knollmueller committed
22 23


Jakob Knollmueller's avatar
Jakob Knollmueller committed
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
class StarbladeEnergy(Energy):
    """The Energy for the starblade problem.

    It implements the information Hamiltonian for separating the image
    data d into a diffuse component exp(s) and a point-source component
    exp(u) with exp(s) + exp(u) = d. Both components are parametrised by a
    single latent field x (the position) through the point-source fraction
    a = PositiveTanh(x):

        u = log(d * a),    s = log(d * (1 - a)).

    Parameters
    ----------
    position : Field
        The current position x of the separation. Its values are clipped
        to [-9, 9] before use.
    parameters : dict
        Dictionary containing all relevant quantities for the inference:

        data : Field
            The image data d (presumably strictly positive, otherwise the
            logarithms above are not finite).
        alpha : Field
            Slope parameter of the point-source prior.
        q : Field
            Cutoff parameter of the point-source prior.
        correlation : LinearOperator
            Operator on Fourier space encoding the diagonal of the prior
            correlation structure of the diffuse component. It is combined
            as ``FFT * correlation * FFT.adjoint`` into the diffuse prior
            covariance S.
        FFT : FFTOperator
            An operator performing the Fourier transform.
        inverter : ConjugateGradient
            The minimization strategy to use for operator inversion.
        var_x : float, optional
            Variance of the Gaussian prior on the position x.
            Defaults to 9.
    """

    # Positions are clipped to [-_POSITION_CLIP, _POSITION_CLIP]; presumably
    # this keeps PositiveTanh away from saturation so that log(d * (1 - a))
    # and log(d * a) stay finite.
    _POSITION_CLIP = 9
    # Prior variance of x used when `parameters` has no 'var_x' entry.
    _DEFAULT_VAR_X = 9.

    def __init__(self, position, parameters):
        x = position.val.clip(-self._POSITION_CLIP, self._POSITION_CLIP)
        position = Field(position.domain, val=x)
        super(StarbladeEnergy, self).__init__(position=position)

        self.parameters = parameters
        self.inverter = parameters['inverter']
        self.d = parameters['data']
        self.FFT = parameters['FFT']
        self.correlation = parameters['correlation']
        self.alpha = parameters['alpha']
        self.q = parameters['q']
        self.var_x = parameters.get('var_x', self._DEFAULT_VAR_X)

        # Diffuse prior covariance, built from the Fourier-space correlation.
        self.S = self.FFT * self.correlation * self.FFT.adjoint

        pos_tanh = PositiveTanh()
        # Point-source fraction a(x) and its derivative da/dx.
        self.a = pos_tanh(self.position)
        self.a_p = pos_tanh.derivative(self.position)

        # Point-source log-brightness u and du/dx.
        self.u = log(self.d * self.a)
        self.u_p = self.a_p / self.a
        # Diffuse log-brightness s and ds/dx.
        one_m_a = 1 - self.a
        self.s = log(self.d * one_m_a)
        self.s_p = - self.a_p / one_m_a

        # Cache for S^{-1} s. Both `value` and `gradient` need it, and each
        # evaluation goes through the FFT operators, so compute it at most
        # once per position.
        self._S_inv_s = None

    def _inverse_S_s(self):
        """Return S^{-1} s, computing it only on first use."""
        if self._S_inv_s is None:
            self._S_inv_s = self.S.inverse(self.s)
        return self._S_inv_s

    def at(self, position):
        """Return a new energy at `position` with the same parameters."""
        return self.__class__(position, parameters=self.parameters)

    @property
    def value(self):
        """Value of the information Hamiltonian at the current position."""
        diffuse = 0.5 * self.s.vdot(self._inverse_S_s())
        point = (self.alpha - 1).vdot(self.u) + self.q.vdot(exp(-self.u))
        # NOTE(review): `integrate` presumably weights by the pixel volume,
        # while the matching gradient term (`s_p`) is unweighted. The two
        # agree only for unit pixel volume; confirm against the domain setup.
        det = self.s.integrate()
        det += 0.5 / self.var_x * self.position.vdot(self.position)
        return diffuse + point + det

    @property
    def gradient(self):
        """Gradient of `value` with respect to the position x."""
        diffuse = self._inverse_S_s() * self.s_p
        point = (self.alpha - 1) * self.u_p - self.q * exp(-self.u) * self.u_p
        det = self.position / self.var_x
        det += self.s_p
        return diffuse + point + det

    @property
    def curvature(self):
        """Approximate curvature of `value` with respect to the position x.

        Second-derivative terms of u and s (d2u/dx2, d2s/dx2) are dropped.
        The remaining point-source term q * exp(-u) * u_p**2 is nonnegative
        for nonnegative q.
        """
        point = self.q * exp(-self.u) * self.u_p ** 2
        # Assuming WienerFilterCurvature represents R^dagger N^-1 R + S^-1:
        # R chains multiplication by ds/dx with FFT.inverse and N is the
        # Fourier-space diffuse correlation, which gives the diffuse
        # s_p S^-1 s_p term. "S" here is the diagonal whose inverse is the
        # point-source term plus the 1/var_x prior on x.
        R = self.FFT.inverse * self.s_p
        N = self.correlation
        S = DiagonalOperator(1/(point + 1/self.var_x))
        return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)