diff --git a/demos/KL_demo.py b/demos/KL_demo.py deleted file mode 100644 index 657ef6ce7fce7ae042c569df5016d434c4c1c343..0000000000000000000000000000000000000000 --- a/demos/KL_demo.py +++ /dev/null @@ -1,138 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2017-2018 Max-Planck-Society -# Author: Jakob Knollmueller -# -# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik - -import numpy as np -from astropy.io import fits -from matplotlib import pyplot as plt -from multiprocessing import Pool - -import nifty4 as ift -from nifty4.library.nonlinearities import PositiveTanh - - -import starblade as sb -from starblade.starblade_energy import StarbladeEnergy -from starblade.starblade_kl import StarbladeKL - -def power_update(KL_energy): - power = 0. - for energy in KL_energy.energy_list: - power += ift.power_analyze(FFT.inverse_times(energy.s), - binbounds=p_space.binbounds) - power /= len(KL_energy.energy_list) - return power - -if __name__ == '__main__': - #specifying location of the input file: - path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits' - path = 'data/frame-u-006174-2-0094.fits' - # path = 'data/frame-g-002821-6-0141.fits' - path = 'data/frame-g-007812-6-0100.fits' - path = 'data/frame-i-004874-3-0692.fits' - - # data = fits.open(path)[1].data - data = fits.open(path)[0].data#[1000:,1250:] - data -= data.min() - 0.001 - # data = np.exp(2*(1.-plt.imread('data/sdss.png').T[0])) - # data = (plt.imread('data/m51_3.jpg').T[0]) - # data = (plt.imread('data/12_FBP.png').T[0]) - - - # - # data = data.clip(min=0.001) - - - data = np.ndarray.astype(data, float) - vmin = np.log(data.min()+0.01) - vmax = np.log(data.max()) - plt.imsave('data.png', np.log(data)) - postanh=PositiveTanh() - alpha = 1.5 - s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1]) - h_space = s_space.get_default_codomain() - data = ift.Field(s_space,val=data) - FFT = ift.FFTOperator(h_space, target=s_space) - binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False) - p_space = ift.PowerSpace(h_space, binbounds=binbounds) - initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), - binbounds=p_space.binbounds) - initial_spectrum /= (p_space.k_lengths+1.)**4 - update_power = True - - initial_x = ift.Field(s_space, val=-1.) - alpha = ift.Field(s_space, val=alpha) - q = ift.Field(s_space, val=1e-30) - ICI = ift.GradientNormController(iteration_limit=100, - tol_abs_gradnorm=1e-3) - inverter = ift.ConjugateGradient(controller=ICI) - - parameters = dict(data=data, power_spectrum=initial_spectrum, - alpha=alpha, q=q, - inverter=inverter, FFT=FFT, - newton_iterations=5, update_power=update_power) - current_x = initial_x - for i in range(10): - Starblade = StarbladeEnergy(position=current_x, parameters=parameters) - samples = [] - for i in range(3): - sample = Starblade.curvature.inverse.draw_sample() - samples.append(sample) - problem = StarbladeKL(current_x, samples,parameters) - - controller = ift.GradientNormController(name="Newton", - tol_abs_gradnorm=1e-5, - iteration_limit=5) - minimizer = ift.RelaxedNewton(controller=controller) - problem, convergence = minimizer(problem) - current_x = problem.position - parameters['power_spectrum'] = power_update(problem) - Starblade = StarbladeEnergy(position=current_x, parameters=parameters) - - # Starblade = sb.build_starblade(data, alpha=alpha) - # for i in range(10): - # Starblade = sb.starblade_iteration(Starblade) - # - # #plotting on logarithmic scale - plt.imsave('diffuse_component.png', (Starblade.s).val,vmin=vmin, vmax=vmax) - plt.imsave('pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax) - Starblade = StarbladeEnergy(position=current_x, parameters=parameters) - var = 0. - mean = 0 - samps = 30 - for i in range(samps): - sam = postanh(Starblade.position+Starblade.curvature.inverse.draw_sample()) - mean += sam - var += sam**2 - - var /= samps - mean /= samps - var -= mean**2 - mask = ift.sqrt(var) < 0.01 +0. - plt.imsave('masked_points.png', mask.val * Starblade.u.val, vmin=vmin, vmax=vmax) - plt.imsave('masked_diffuse.png', mask.val * Starblade.s.val) - - plt.imsave('std.png', np.log(np.sqrt(var.val)*data.val), vmin=-3.3) - # plt.figure() - # k_lenghts = Starblade.power_spectrum.domain[0].k_lengths - # plt.plot(k_lenghts, Starblade.power_spectrum.val) - # plt.title('power spectrum') - # plt.yscale('log') - # plt.xscale('log') - # plt.ylabel('power') - # plt.xscale('harmonic mode') - # plt.savefig('power_spectrum.png') diff --git a/demos/clipping.py b/demos/clipping.py deleted file mode 100644 index 2e0c72a678f8aa1b2b9ad74a8f2e61947bb17086..0000000000000000000000000000000000000000 --- a/demos/clipping.py +++ /dev/null @@ -1,91 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2017-2018 Max-Planck-Society -# Author: Jakob Knollmueller -# -# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik - -import numpy as np -from astropy.io import fits -from matplotlib import pyplot as plt -from scipy.ndimage.filters import median_filter -import starblade as sb - -if __name__ == '__main__': - #specifying location of the input file: - # path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits' - # data = fits.open(path)[1].data - path = 'data/frame-i-004874-3-0692.fits' - path ='data/check.fits' - # data = fits.open(path)[1].data - data = fits.open(path)[0].data[1000:,1250:] - data -= data.min() - 0.001 - data = data.clip(min=0.001) - - data_true = data.copy() - - data = np.ndarray.astype(data, float) - vmin = np.log(data.min()+0.01) - vmax = np.log(data.max()) - - local_size = 4 - for i in range(5): - for i in range(data.shape[0]/local_size): - for j in range(data.shape[1]/local_size): - local_data = data[i*local_size:(1+i)*local_size,j*local_size:(1+j)*local_size] - local_data_median = np.median(local_data) - local_data_var = local_data.var() - local_data = local_data.clip(min=local_data_median - 3*np.sqrt(local_data_var), - max=local_data_median + 3*np.sqrt(local_data_var)) - data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] = local_data - - - background = np.empty_like(data) - crowded = np.zeros_like(data) - for i in range(data.shape[0] / local_size): - for j in range(data.shape[1] / local_size): - local_true_data = data_true[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] - local_data = data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] - local_true_var = local_true_data.var() - local_var = local_data.var() - if 0.8 * np.sqrt(local_true_var) > np.sqrt(local_var): - background[i * local_size:(1 + i) * local_size, - j * local_size:(1 + j) * local_size] = 2.5*np.median(local_data)-1.5*local_data.mean() - crowded[i * local_size:(1 + i) * local_size, - j * local_size:(1 + j) * local_size] = 1. - else: - background[i * local_size:(1 + i) * local_size, - j * local_size:(1 + j) * local_size] = local_data.mean() - - background = median_filter(background, size=(local_size,local_size)) - # alpha = 1.25 - # Starblade = sb.build_starblade(data, alpha=alpha) - # for i in range(10): - # Starblade = sb.starblade_iteration(Starblade) - # - # plotting on logarithmic scale - # background += background.min() - plt.gray() - plt.imsave('diffuse_component.png', np.log(background))#, vmin=vmin, vmax=vmax) - plt.imsave('pointlike_component.png', (data_true - background), vmin=vmin, vmax=vmax) - plt.imsave('crowded.png',crowded) - # plt.figure() - # k_lenghts = Starblade.power_spectrum.domain[0].k_lengths - # plt.plot(k_lenghts, Starblade.power_spectrum.val) - # plt.title('power spectrum') - # plt.yscale('log') - # plt.xscale('log') - # plt.ylabel('power') - # plt.xscale('harmonic mode') - # plt.savefig('power_spectrum.png') diff --git a/starblade/starblade_kl.py b/starblade/starblade_kl.py index 8d44fd6ab1dcad2e679cbe3e7cd0dece0e72574e..ff05ce9353bdacd6a9200cbbea66416b66adf93d 100644 --- a/starblade/starblade_kl.py +++ b/starblade/starblade_kl.py @@ -20,6 +20,30 @@ from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler from starblade_energy import StarbladeEnergy class StarbladeKL(Energy): + """The Kullback-Leibler divergence for the starblade problem. + + Parameters + ---------- + position : Field + The current position of the separation. + samples : List + A list containing residual samples. + parameters : Dictionary + Dictionary containing all relevant quantities for the inference, + data : Field + The image data. + alpha : Field + Slope parameter of the point-source prior + q : Field + Cutoff parameter of the point-source prior + power_spectrum : callable or Field + An object that contains the power spectrum of the diffuse component + as a function of the harmonic mode. + FFT : FFTOperator + An operator performing the Fourier transform + inverter : ConjugateGradient + the minimization strategy to use for operator inversion + """ def __init__(self, position, samples, parameters): super(StarbladeKL, self).__init__(position=position) diff --git a/starblade/sugar.py b/starblade/sugar.py index f77ccaaaf026a4693aabfad1bbd4677a48be2748..4c9253cb6072321dcad76718c2a3690ab6aab253 100644 --- a/starblade/sugar.py +++ b/starblade/sugar.py @@ -80,6 +80,8 @@ def starblade_iteration(starblade, samples=3): ---------- starblade : StarbladeEnergy An instance of an Starblade Energy + samples : int + Number of samples drawn in order to estimate the KL. If zero the MAP is calculated (default: 3). """ controller = ift.GradientNormController(name="Newton", tol_abs_gradnorm=1e-8, @@ -152,6 +154,13 @@ def multi_starblade_iteration(MultiStarblade, processes = 1): return NewStarblades def update_power(energy): + """ Calculates a new estimate of the power spectrum given a StarbladeEnergy or StarbladeKL. + For Energy the MAP estimate of the power spectrum is calculated and for KL the variational estimate. + ---------- + energy : StarbladeEnergy or StarbladeKL + An instance of an StarbladeEnergy or StarbladeKL + + """ if isinstance(energy, StarbladeKL): power = 0. for en in energy.energy_list: