From e091563a73628174bdd2434ff72668dbb4e56b67 Mon Sep 17 00:00:00 2001 From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de> Date: Wed, 25 Apr 2018 16:22:00 +0200 Subject: [PATCH] enabling KL --- demos/KL_demo.py | 138 ++++++++++++++++++++++++++++++++++++++ demos/clipping.py | 91 +++++++++++++++++++++++++ demos/demo.py | 18 +++-- starblade/__init__.py | 2 + starblade/starblade_kl.py | 60 +++++++++++++++++ starblade/sugar.py | 38 ++++++++--- 6 files changed, 333 insertions(+), 14 deletions(-) create mode 100644 demos/KL_demo.py create mode 100644 demos/clipping.py create mode 100644 starblade/starblade_kl.py diff --git a/demos/KL_demo.py b/demos/KL_demo.py new file mode 100644 index 0000000..657ef6c --- /dev/null +++ b/demos/KL_demo.py @@ -0,0 +1,138 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2017-2018 Max-Planck-Society +# Author: Jakob Knollmueller +# +# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik + +import numpy as np +from astropy.io import fits +from matplotlib import pyplot as plt +from multiprocessing import Pool + +import nifty4 as ift +from nifty4.library.nonlinearities import PositiveTanh + + +import starblade as sb +from starblade.starblade_energy import StarbladeEnergy +from starblade.starblade_kl import StarbladeKL + +def power_update(KL_energy): + power = 0. + for energy in KL_energy.energy_list: + power += ift.power_analyze(FFT.inverse_times(energy.s), + binbounds=p_space.binbounds) + power /= len(KL_energy.energy_list) + return power + +if __name__ == '__main__': + #specifying location of the input file: + path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits' + path = 'data/frame-u-006174-2-0094.fits' + # path = 'data/frame-g-002821-6-0141.fits' + path = 'data/frame-g-007812-6-0100.fits' + path = 'data/frame-i-004874-3-0692.fits' + + # data = fits.open(path)[1].data + data = fits.open(path)[0].data#[1000:,1250:] + data -= data.min() - 0.001 + # data = np.exp(2*(1.-plt.imread('data/sdss.png').T[0])) + # data = (plt.imread('data/m51_3.jpg').T[0]) + # data = (plt.imread('data/12_FBP.png').T[0]) + + + # + # data = data.clip(min=0.001) + + + data = np.ndarray.astype(data, float) + vmin = np.log(data.min()+0.01) + vmax = np.log(data.max()) + plt.imsave('data.png', np.log(data)) + postanh=PositiveTanh() + alpha = 1.5 + s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1]) + h_space = s_space.get_default_codomain() + data = ift.Field(s_space,val=data) + FFT = ift.FFTOperator(h_space, target=s_space) + binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False) + p_space = ift.PowerSpace(h_space, binbounds=binbounds) + initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), + binbounds=p_space.binbounds) + initial_spectrum /= (p_space.k_lengths+1.)**4 + update_power = True + + initial_x = ift.Field(s_space, val=-1.) + alpha = ift.Field(s_space, val=alpha) + q = ift.Field(s_space, val=1e-30) + ICI = ift.GradientNormController(iteration_limit=100, + tol_abs_gradnorm=1e-3) + inverter = ift.ConjugateGradient(controller=ICI) + + parameters = dict(data=data, power_spectrum=initial_spectrum, + alpha=alpha, q=q, + inverter=inverter, FFT=FFT, + newton_iterations=5, update_power=update_power) + current_x = initial_x + for i in range(10): + Starblade = StarbladeEnergy(position=current_x, parameters=parameters) + samples = [] + for i in range(3): + sample = Starblade.curvature.inverse.draw_sample() + samples.append(sample) + problem = StarbladeKL(current_x, samples,parameters) + + controller = ift.GradientNormController(name="Newton", + tol_abs_gradnorm=1e-5, + iteration_limit=5) + minimizer = ift.RelaxedNewton(controller=controller) + problem, convergence = minimizer(problem) + current_x = problem.position + parameters['power_spectrum'] = power_update(problem) + Starblade = StarbladeEnergy(position=current_x, parameters=parameters) + + # Starblade = sb.build_starblade(data, alpha=alpha) + # for i in range(10): + # Starblade = sb.starblade_iteration(Starblade) + # + # #plotting on logarithmic scale + plt.imsave('diffuse_component.png', (Starblade.s).val,vmin=vmin, vmax=vmax) + plt.imsave('pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax) + Starblade = StarbladeEnergy(position=current_x, parameters=parameters) + var = 0. + mean = 0 + samps = 30 + for i in range(samps): + sam = postanh(Starblade.position+Starblade.curvature.inverse.draw_sample()) + mean += sam + var += sam**2 + + var /= samps + mean /= samps + var -= mean**2 + mask = ift.sqrt(var) < 0.01 +0. + plt.imsave('masked_points.png', mask.val * Starblade.u.val, vmin=vmin, vmax=vmax) + plt.imsave('masked_diffuse.png', mask.val * Starblade.s.val) + + plt.imsave('std.png', np.log(np.sqrt(var.val)*data.val), vmin=-3.3) + # plt.figure() + # k_lenghts = Starblade.power_spectrum.domain[0].k_lengths + # plt.plot(k_lenghts, Starblade.power_spectrum.val) + # plt.title('power spectrum') + # plt.yscale('log') + # plt.xscale('log') + # plt.ylabel('power') + # plt.xscale('harmonic mode') + # plt.savefig('power_spectrum.png') diff --git a/demos/clipping.py b/demos/clipping.py new file mode 100644 index 0000000..2e0c72a --- /dev/null +++ b/demos/clipping.py @@ -0,0 +1,91 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2017-2018 Max-Planck-Society +# Author: Jakob Knollmueller +# +# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik + +import numpy as np +from astropy.io import fits +from matplotlib import pyplot as plt +from scipy.ndimage.filters import median_filter +import starblade as sb + +if __name__ == '__main__': + #specifying location of the input file: + # path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits' + # data = fits.open(path)[1].data + path = 'data/frame-i-004874-3-0692.fits' + path ='data/check.fits' + # data = fits.open(path)[1].data + data = fits.open(path)[0].data[1000:,1250:] + data -= data.min() - 0.001 + data = data.clip(min=0.001) + + data_true = data.copy() + + data = np.ndarray.astype(data, float) + vmin = np.log(data.min()+0.01) + vmax = np.log(data.max()) + + local_size = 4 + for i in range(5): + for i in range(data.shape[0]/local_size): + for j in range(data.shape[1]/local_size): + local_data = data[i*local_size:(1+i)*local_size,j*local_size:(1+j)*local_size] + local_data_median = np.median(local_data) + local_data_var = local_data.var() + local_data = local_data.clip(min=local_data_median - 3*np.sqrt(local_data_var), + max=local_data_median + 3*np.sqrt(local_data_var)) + data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] = local_data + + + background = np.empty_like(data) + crowded = np.zeros_like(data) + for i in range(data.shape[0] / local_size): + for j in range(data.shape[1] / local_size): + local_true_data = data_true[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] + local_data = data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] + local_true_var = local_true_data.var() + local_var = local_data.var() + if 0.8 * np.sqrt(local_true_var) > np.sqrt(local_var): + background[i * local_size:(1 + i) * local_size, + j * local_size:(1 + j) * local_size] = 2.5*np.median(local_data)-1.5*local_data.mean() + crowded[i * local_size:(1 + i) * local_size, + j * local_size:(1 + j) * local_size] = 1. + else: + background[i * local_size:(1 + i) * local_size, + j * local_size:(1 + j) * local_size] = local_data.mean() + + background = median_filter(background, size=(local_size,local_size)) + # alpha = 1.25 + # Starblade = sb.build_starblade(data, alpha=alpha) + # for i in range(10): + # Starblade = sb.starblade_iteration(Starblade) + # + # plotting on logarithmic scale + # background += background.min() + plt.gray() + plt.imsave('diffuse_component.png', np.log(background))#, vmin=vmin, vmax=vmax) + plt.imsave('pointlike_component.png', (data_true - background), vmin=vmin, vmax=vmax) + plt.imsave('crowded.png',crowded) + # plt.figure() + # k_lenghts = Starblade.power_spectrum.domain[0].k_lengths + # plt.plot(k_lenghts, Starblade.power_spectrum.val) + # plt.title('power spectrum') + # plt.yscale('log') + # plt.xscale('log') + # plt.ylabel('power') + # plt.xscale('harmonic mode') + # plt.savefig('power_spectrum.png') diff --git a/demos/demo.py b/demos/demo.py index 1f850b7..ac3db51 100644 --- a/demos/demo.py +++ b/demos/demo.py @@ -25,18 +25,26 @@ import starblade as sb if __name__ == '__main__': #specifying location of the input file: path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits' - data = fits.open(path)[1].data + path = 'data/frame-i-004874-3-0692.fits' + + # data = fits.open(path)[1].data + data = fits.open(path)[0].data[1000:15000,1250:1750] + data -= data.min() - 0.001 + # data = 1.-plt.imread('data/sdss.png').T[0] + # data = fits.open(path)[1].data + + data = data.clip(min=0.0001) - data = data.clip(min=0.001) data = np.ndarray.astype(data, float) - vmin = np.log(data.min()+0.01) + vmin = np.log(data.min()+0.2) vmax = np.log(data.max()) + plt.imsave('data.png', np.log(data),vmin=vmin,vmax=vmax) alpha = 1.25 Starblade = sb.build_starblade(data, alpha=alpha) for i in range(10): - Starblade = sb.starblade_iteration(Starblade) + Starblade = sb.starblade_iteration(Starblade, samples=i) #plotting on logarithmic scale plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax) @@ -48,5 +56,5 @@ if __name__ == '__main__': plt.yscale('log') plt.xscale('log') plt.ylabel('power') - plt.xscale('harmonic mode') + plt.xlabel('harmonic mode') plt.savefig('power_spectrum.png') diff --git a/starblade/__init__.py b/starblade/__init__.py index f86f342..8fff660 100644 --- a/starblade/__init__.py +++ b/starblade/__init__.py @@ -1,2 +1,4 @@ from .sugar import (build_starblade, starblade_iteration, build_multi_starblade, multi_starblade_iteration) +from .starblade_kl import StarbladeKL +from .starblade_energy import StarbladeEnergy diff --git a/starblade/starblade_kl.py b/starblade/starblade_kl.py new file mode 100644 index 0000000..8d44fd6 --- /dev/null +++ b/starblade/starblade_kl.py @@ -0,0 +1,60 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2017-2018 Max-Planck-Society +# Author: Jakob Knollmueller +# +# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik + +from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler +from starblade_energy import StarbladeEnergy + +class StarbladeKL(Energy): + + def __init__(self, position, samples, parameters): + super(StarbladeKL, self).__init__(position=position) + self.samples = samples + self.parameters = parameters + self.energy_list=[] + for sample in samples: + energy = StarbladeEnergy(position+sample,parameters) + self.energy_list.append(energy) + + + def at(self, position): + return self.__class__(position, samples=self.samples, parameters=self.parameters) + + @property + def value(self): + value = 0. + for energy in self.energy_list: + value += energy.value + value /= len(self.energy_list) + return value + + @property + def gradient(self): + gradient = Field.zeros(self.position.domain) + for energy in self.energy_list: + gradient += energy.gradient + gradient /= len(self.energy_list) + return gradient + + @property + def curvature(self): + curvature = DiagonalOperator(Field.zeros(self.position.domain)) + for energy in self.energy_list: + curvature += energy.curvature + curvature *= Field(self.position.domain,val=1./len(self.energy_list)) + return InversionEnabler(curvature, self.parameters['inverter']) + diff --git a/starblade/sugar.py b/starblade/sugar.py index 7a235c8..f77ccaa 100644 --- a/starblade/sugar.py +++ b/starblade/sugar.py @@ -21,9 +21,9 @@ from multiprocessing import Pool import nifty4 as ift from .starblade_energy import StarbladeEnergy +from .starblade_kl import StarbladeKL - -def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iterations = 3, +def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iterations = 3, manual_power_spectrum = None): """ Setting up the StarbladeEnergy for the given data and parameters Parameters @@ -69,9 +69,12 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iteratio inverter=inverter, FFT=FFT, newton_iterations=newton_iterations, update_power=update_power) Starblade = StarbladeEnergy(position=initial_x, parameters=parameters) + + return Starblade -def starblade_iteration(starblade): + +def starblade_iteration(starblade, samples=3): """ Performing one Newton minimization step Parameters ---------- @@ -82,14 +85,19 @@ def starblade_iteration(starblade): tol_abs_gradnorm=1e-8, iteration_limit=starblade.newton_iterations) minimizer = ift.RelaxedNewton(controller=controller) - energy, convergence = minimizer(starblade) + sample_list = [] + for i in range(samples): + sample = starblade.curvature.inverse.draw_sample() + sample_list.append(sample) + if len(sample_list)>0: + energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters) + else: + energy = starblade + energy, convergence = minimizer(energy) new_position = energy.position new_parameters = energy.parameters - if energy.update_power: - h_space = energy.correlation.domain[0] - FFT = energy.FFT - binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False) - new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds) + if energy.parameters['update_power']: + new_power = update_power(energy) # new_power /= (new_power.domain[0].k_lengths+1.)**2 new_parameters['power_spectrum'] = new_power @@ -143,6 +151,18 @@ def multi_starblade_iteration(MultiStarblade, processes = 1): NewStarblades.append(starblade_iteration(starblade)) return NewStarblades +def update_power(energy): + if isinstance(energy, StarbladeKL): + power = 0. + for en in energy.energy_list: + power += ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s), + binbounds=en.parameters['power_spectrum'].domain[0].binbounds) + power /= len(energy.energy_list) + else: + power = ift.power_analyze(energy.FFT.inverse_times(energy.s), + binbounds=energy.parameters['power_spectrum'].domain[0].binbounds) + return power + if __name__ == '__main__': pass -- GitLab