# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik

from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4.library import WienerFilterCurvature
from nifty4.library.nonlinearities import PositiveTanh


class StarbladeEnergy(Energy):
    """Energy (information Hamiltonian) of the starblade separation.

    The data ``d`` are split into two parts, ``d * a`` and ``d * (1 - a)``,
    where ``a = PositiveTanh()(position)``.  The logarithms of both parts
    are stored as ``u`` (entering the point-source terms) and ``s``
    (entering the diffuse term); ``u_p`` and ``s_p`` are their derivatives
    with respect to the position.

    Parameters
    ----------
    position : Field
        The current position of the separation.  Its values are clipped to
        the interval [-9, 9] before being stored.
    parameters : dict
        Dictionary containing all relevant quantities for the inference:

        data : Field
            The image data.
        alpha : Field
            Slope parameter of the point-source prior.
        q : Field
            Cutoff parameter of the point-source prior.
        correlation : Field
            Encodes, in Fourier space, the diagonal of the prior
            correlation structure of the diffuse component.
            NOTE(review): it is multiplied into an operator chain and
            passed as ``N`` to ``WienerFilterCurvature`` -- confirm whether
            callers actually supply an operator here.
        FFT : FFTOperator
            An operator performing the Fourier transform.
        inverter : ConjugateGradient
            The minimization strategy to use for operator inversion.
    """

    def __init__(self, position, parameters):
        # Restrict the position values to [-9, 9] and rebuild the field on
        # the same domain before handing it to the base class.
        clipped_val = position.val.clip(-9, 9)
        bounded = Field(position.domain, val=clipped_val)
        super(StarbladeEnergy, self).__init__(position=bounded)

        # Keep the full dictionary so that `at` can build sibling energies.
        self.parameters = parameters
        self.inverter = parameters['inverter']
        self.d = parameters['data']
        self.FFT = parameters['FFT']
        self.correlation = parameters['correlation']
        self.alpha = parameters['alpha']
        self.q = parameters['q']

        # Variance of the quadratic term in the position (see `value`).
        self.var_x = 9.

        # Correlation sandwiched between the transform and its adjoint.
        self.S = self.FFT * self.correlation * self.FFT.adjoint

        # Separation fraction and its derivative at the current position.
        squash = PositiveTanh()
        self.a = squash(self.position)
        self.a_p = squash.derivative(self.position)

        # Logarithm of the `a` share of the data and its derivative.
        self.u = log(self.d * self.a)
        self.u_p = self.a_p/self.a

        # Logarithm of the complementary `1 - a` share and its derivative.
        complement = 1 - self.a
        self.s = log(self.d * complement)
        self.s_p = (-self.a_p) / complement

    def at(self, position):
        """Return a new energy of the same class at `position`.

        The parameter dictionary of this instance is reused unchanged.
        """
        energy_cls = self.__class__
        return energy_cls(position, parameters=self.parameters)

    @property
    def value(self):
        """Value of the energy at the current position.

        Sum of the diffuse term ``0.5 * s^T S^-1 s``, the point-source
        terms ``(alpha - 1)^T u + q^T exp(-u)``, the integral of ``s`` and
        the quadratic term ``position^T position / (2 * var_x)``.
        """
        s_filtered = self.S.inverse(self.s)
        diffuse_term = 0.5 * self.s.vdot(s_filtered)

        slope_term = (self.alpha - 1).vdot(self.u)
        cutoff_term = self.q.vdot(exp(-self.u))
        point_term = slope_term + cutoff_term

        quadratic = 0.5 / self.var_x * self.position.vdot(self.position)
        det_term = self.s.integrate() + quadratic

        return diffuse_term + point_term + det_term

    @property
    def gradient(self):
        """Gradient of the energy with respect to the position."""
        diffuse_part = self.S.inverse(self.s) * self.s_p

        slope_part = (self.alpha - 1) * self.u_p
        cutoff_part = self.q * exp(-self.u) * self.u_p
        point_part = slope_part - cutoff_part

        det_part = self.position / self.var_x + self.s_p

        return diffuse_part + point_part + det_part

    @property
    def curvature(self):
        """Curvature operator, assembled as a `WienerFilterCurvature`.

        The point-source contribution plus ``1 / var_x`` is inverted
        element-wise and wrapped in a `DiagonalOperator` that is passed as
        ``S``; ``FFT.inverse * s_p`` is passed as ``R`` and the correlation
        as ``N``.
        """
        point_curv = self.q * exp(-self.u) * self.u_p ** 2
        diag_op = DiagonalOperator(1 / (point_curv + 1 / self.var_x))
        response = self.FFT.inverse * self.s_p
        return WienerFilterCurvature(R=response, N=self.correlation,
                                     S=diag_op, inverter=self.inverter)