diff --git a/starblade/starblade_energy.py b/starblade/starblade_energy.py index e9eaae51c20e6d1e9902bd0663e3db1927c0b63d..6bae0583f1493d4e8c4328c9c208af3ba2445252 100644 --- a/starblade/starblade_energy.py +++ b/starblade/starblade_energy.py @@ -1,8 +1,47 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2017-2018 Max-Planck-Society +# Author: Jakob Knollmueller +# +# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik + import nifty4 as ift import numpy as np class StarbladeEnergy(ift.Energy): + """The Energy for the starblade problem. + + It implements the Information Hamiltonian of the separation of d + + Parameters + ---------- + position : Field + The current position of the separation. + data : Field + The image data. + alpha : float + Slope parameter of the point-source prior + q : float + Cutoff parameter of the point-source prior + correlation : Field + A field in the Fourier space which encodes the diagonal of the prior + correlation structure of the diffuse component + inverter : ConjugateGradient (optional) + the minimization strategy to use for operator inversion + If None, the energy will not be able to compute curvatures + """ def __init__(self, position, data, alpha, q, correlation, inverter=None): if (position > 9.).any() or (position < -9.).any(): raise ValueError("position outside allowed range") @@ -22,8 +61,8 @@ class StarbladeEnergy(ift.Energy): self._ptanh = ift.library.PositiveTanh() a = self._ptanh(position) a_p = self._ptanh.derivative(position) - one_m_a = 1. - a - s = ift.log(data * one_m_a) + one_m_a = 1.-a + s = ift.log(data*one_m_a) if correlation is None: binbounds = ift.PowerSpace.useful_binbounds(h_space, @@ -33,13 +72,13 @@ class StarbladeEnergy(ift.Energy): correlation = ift.create_power_operator(h_space, correlation) self._correlation = correlation - S = FFT * correlation * FFT.adjoint + S = FFT*correlation*FFT.adjoint - usum = ift.log(data * a).sum() + usum = ift.log(data*a).sum() var_x = 9. Sis = S.inverse(s) - qexpmu = self._q / (data*a) + qexpmu = self._q/(data*a) diffuse = 0.5 * s.vdot(Sis) point = (alpha-1)*usum + qexpmu.sum() @@ -56,7 +95,7 @@ class StarbladeEnergy(ift.Energy): self._gradient = (diffuse + point + det).lock() if inverter is not None: # curvature is needed - point = qexpmu * u_p ** 2 + point = qexpmu * u_p**2 R = FFT.inverse * s_p N = self._correlation S = ift.DiagonalOperator(1./(point + 1./var_x)) diff --git a/starblade/sugar.py b/starblade/sugar.py index 33c88a744af529090b423bf846763f9ac3c6a92a..c91e0cdfe360d54e6138be32496a1945d57101bc 100644 --- a/starblade/sugar.py +++ b/starblade/sugar.py @@ -38,12 +38,19 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1]) data = ift.Field.from_global_data(s_space, data) + h_space = s_space.get_default_codomain() + FFT = ift.FFTOperator(h_space, s_space) + binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False) + correlation = ift.power_analyze(FFT.inverse_times(ift.log(data)), + binbounds=binbounds) + correlation /= (ift.PowerSpace(h_space, binbounds).k_lengths+1.)**2 + correlation = ift.create_power_operator(h_space, correlation) #if BFGS: # return StarbladeEnergy.make(ift.Field.full(s_space, -1.), data, alpha) ICI = ift.GradientNormController(iteration_limit=cg_iterations, tol_abs_gradnorm=1e-5) inverter = ift.ConjugateGradient(controller=ICI) - return StarbladeEnergy.make(ift.Field.full(s_space, -1.), data, alpha, q, + return StarbladeEnergy(ift.Field.full(s_space, -1.), data, alpha, q, correlation, inverter) @@ -64,7 +71,9 @@ def starblade_iteration(starblade, iterations=3): #else: minimizer = ift.RelaxedNewton(controller=controller) starblade, convergence = minimizer(starblade) - return starblade.with_new_correlation() + # FIXME: this is not final yet! + return starblade + #return starblade.with_new_correlation() def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): @@ -82,7 +91,8 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500). """ - return [build_starblade(data[..., i].copy(), alpha, q, cg_iterations) for i in range(data.shape[-1])] + return [build_starblade(data[..., i].copy(), alpha, q, cg_iterations) + for i in range(data.shape[-1])] def multi_starblade_iteration(MultiStarblade, multiprocessing=False):