From c8ff706777c26676ff86989081b24cdfbf04aac1 Mon Sep 17 00:00:00 2001 From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de> Date: Tue, 10 Apr 2018 14:32:29 +0200 Subject: [PATCH] documentation, cleanup --- 1d_separation.py | 2 +- demo.py | 25 ++++++++++ gui_app.py | 2 +- hubble_separation.py | 12 +++-- multichannel_demo.py | 28 +++++++++++ point_separation.py | 74 ----------------------------- rgb_separation.py | 25 ---------- separation_energy.py | 32 +++++++++++-- sugar.py | 111 +++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 203 insertions(+), 108 deletions(-) create mode 100644 demo.py create mode 100644 multichannel_demo.py delete mode 100644 point_separation.py delete mode 100644 rgb_separation.py create mode 100644 sugar.py diff --git a/1d_separation.py b/1d_separation.py index f90f0dd..653b090 100644 --- a/1d_separation.py +++ b/1d_separation.py @@ -1,4 +1,4 @@ -from point_separation import build_problem, problem_iteration +from sugar import build_problem, problem_iteration import nifty4 as ift import numpy as np from matplotlib import rc diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..1adbbb6 --- /dev/null +++ b/demo.py @@ -0,0 +1,25 @@ +from sugar import build_starblade, starblade_iteration +from matplotlib import pyplot as plt +from astropy.io import fits + +import numpy as np + +if __name__ == '__main__': + #specifying location of the input file: + path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits' + data = fits.open(path)[1].data + + data = data.clip(min=0.001) + + data = np.ndarray.astype(data, float) + vmin = np.log(data.min()+0.01) + vmax = np.log(data.max()) + + alpha = 1.25 + Starblade = build_starblade(data, alpha=alpha) + for i in range(10): + Starblade = starblade_iteration(Starblade) + + #plotting on logarithmic scale + plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax) + plt.imsave('pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax) \ No newline at end of file diff --git a/gui_app.py b/gui_app.py index 6972990..61aa4bf 100644 --- a/gui_app.py +++ b/gui_app.py @@ -2,7 +2,7 @@ import matplotlib matplotlib.use('agg') # matplotlib.use('module://kivy.garden.matplotlib.backend_kivy') -from point_separation import build_multi_problem, multi_problem_iteration,load_data +from sugar import build_multi_problem, multi_problem_iteration,load_data from kivy.app import App from kivy.uix.widget import Widget diff --git a/hubble_separation.py b/hubble_separation.py index 7d3c435..f31c09f 100644 --- a/hubble_separation.py +++ b/hubble_separation.py @@ -1,4 +1,4 @@ -from point_separation import build_problem, problem_iteration, load_data +from sugar import build_problem, problem_iteration, load_data from nifty4 import * import numpy as np from matplotlib import rc @@ -8,14 +8,20 @@ from matplotlib import pyplot as plt from matplotlib.colors import LogNorm from mpl_toolkits.axes_grid1 import make_axes_locatable from mpl_toolkits.axes_grid1 import AxesGrid +from astropy.io import fits + np.random.seed(42) if __name__ == '__main__': path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits' - data = load_data(path) - alpha = 1.3 + data = fits.open(path)[1].data + + data = data.clip(min=0.001) + + data = np.ndarray.astype(data, float) + alpha = 1.25 diff --git a/multichannel_demo.py b/multichannel_demo.py new file mode 100644 index 0000000..52ffcd8 --- /dev/null +++ b/multichannel_demo.py @@ -0,0 +1,28 @@ +from sugar import build_multi_starblade, multi_starblade_iteration +from matplotlib import pyplot as plt +import numpy as np + +if __name__ == '__main__': + + # data = plt.imread('10Keso1242a.tif') + data = plt.imread('eso1242a.jpg') + + data = data.astype(float) + data = data.clip(0.0001) + alpha = 1.25 + MultiStarblade = build_multi_starblade(data, alpha) + + for i in range(2): + MultiStarblade = multi_starblade_iteration(MultiStarblade, multiprocessing=True) + + #plotting a three channel RGB image in each iteration + diffuse = np.empty_like(data) + point = np.empty_like(data) + for i in range(len(MultiStarblade)): + diffuse[...,i] = np.exp(MultiStarblade[i].s.val) + point[...,i] = np.exp(MultiStarblade[i].u.val) + + plt.imsave('rgb_diffuse.jpg',diffuse/255.) + plt.imsave('rgb_point.jpg',point/255.) + + diff --git a/point_separation.py b/point_separation.py deleted file mode 100644 index 0746292..0000000 --- a/point_separation.py +++ /dev/null @@ -1,74 +0,0 @@ -import nifty4 as ift -import numpy as np -from matplotlib import pyplot as plt -from astropy.io import fits -from separation_energy import SeparationEnergy -from nifty4.library.nonlinearities import PositiveTanh - -def load_data(path): - - if path[-5:] == '.fits': - data = fits.open(path)[1].data - else: - data = plt.imread(path)[:,:,0] - - data = data.clip(min=0.001) - data = np.ndarray.astype(data, float) - return data - -def build_problem(data, alpha): - s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1]) - h_space = s_space.get_default_codomain() - data = ift.Field(s_space,val=data) - FFT = ift.FFTOperator(h_space) - binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False) - p_space = ift.PowerSpace(h_space, binbounds=binbounds) - initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), binbounds=p_space.binbounds) - initial_correlation = ift.create_power_operator(h_space, initial_spectrum) - initial_x = ift.Field(s_space, val=-1.) - alpha = ift.Field(s_space, val=alpha) - q = ift.Field(s_space, val=10e-40) - pos_tanh = PositiveTanh() - ICI = ift.GradientNormController(iteration_limit=500, - tol_abs_gradnorm=1e-5) - inverter = ift.ConjugateGradient(controller=ICI) - - parameters = dict(data=data, correlation=initial_correlation, - alpha=alpha, q=q, - inverter=inverter, FFT=FFT, pos_tanh=pos_tanh) - separationEnergy = SeparationEnergy(position=initial_x, parameters=parameters) - return separationEnergy - -def problem_iteration(energy, iterations=3): - controller = ift.GradientNormController(name="test1", tol_abs_gradnorm=0.00000001, iteration_limit=iterations) - minimizer = ift.RelaxedNewton(controller=controller) - energy, convergence = minimizer(energy) - new_position = energy.position - h_space = energy.correlation.domain[0] - FFT = energy.FFT - binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False) - new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds) - new_correlation = ift.create_power_operator(h_space, new_power) - new_parameters = energy.parameters - new_parameters['correlation'] = new_correlation - new_energy = SeparationEnergy(new_position, new_parameters) - return new_energy - -def build_multi_problem(data, alpha): - energy_list = [] - for i in range(data.shape[-1]): - energy = build_problem(data[...,i],alpha) - energy_list.append(energy) - return energy_list - -def multi_problem_iteration(energy_list): - new_energy = [] - for energy in energy_list: - new_energy.append(problem_iteration(energy)) - return new_energy - -if __name__ == '__main__': - pass - - - diff --git a/rgb_separation.py b/rgb_separation.py deleted file mode 100644 index 2eb14e1..0000000 --- a/rgb_separation.py +++ /dev/null @@ -1,25 +0,0 @@ -from point_separation import build_multi_problem, multi_problem_iteration -from matplotlib import pyplot as plt -import numpy as np - -if __name__ == '__main__': - # data = plt.imread('eso1242a.jpg') - data = plt.imread('10Keso1242a.tif') - data = data.astype(float) - data = data.clip(0.0001) - energy_list = build_multi_problem(data, 1.2) - - for i in range(10): - energy_list = multi_problem_iteration(energy_list) - - - diffuse = np.empty_like(data) - point = np.empty_like(data) - for i in range(len(energy_list)): - diffuse[...,i] = np.exp(energy_list[i].s.val) - point[...,i] = np.exp(energy_list[i].u.val) - - plt.imsave('rgb_diffuse.jpg',diffuse/255.) - plt.imsave('rgb_point.jpg',point/255.) - - diff --git a/separation_energy.py b/separation_energy.py index 76d7810..a1ca6ca 100644 --- a/separation_energy.py +++ b/separation_energy.py @@ -1,14 +1,39 @@ from nifty4 import Energy, Field, log, exp, DiagonalOperator from nifty4.library import WienerFilterCurvature +from nifty4.library.nonlinearities import PositiveTanh -class SeparationEnergy(Energy): +class StarbladeEnergy(Energy): + """The Energy for the starblade problem. + + It implements the Information Hamiltonian of the separation of d + + Parameters + ---------- + position : Field + The current position of the separation. + parameters : Dictionary + Dictionary containing all relevant quantities for the inference, + data : Field + The image data. + alpha : Field + Slope parameter of the point-source prior + q : Field + Cutoff parameter of the point-source prior + correlation : Field + A field in the Fourier space which encodes the diagonal of the prior + correlation structure of the diffuse component + FFT : FFTOperator + An operator performing the Fourier transform + inverter : ConjugateGradient + the minimization strategy to use for operator inversion + """ def __init__(self, position, parameters): x = position.val.clip(-9, 9) position = Field(position.domain, val=x) - super(SeparationEnergy, self).__init__(position=position) + super(StarbladeEnergy, self).__init__(position=position) self.parameters = parameters self.inverter = parameters['inverter'] @@ -17,8 +42,7 @@ class SeparationEnergy(Energy): self.correlation = parameters['correlation'] self.alpha = parameters['alpha'] self.q = parameters['q'] - pos_tanh = parameters['pos_tanh'] - + pos_tanh = PositiveTanh() self.S = self.FFT * self.correlation * self.FFT.adjoint self.a = pos_tanh(self.position) self.a_p = pos_tanh.derivative(self.position) diff --git a/sugar.py b/sugar.py new file mode 100644 index 0000000..3c758c2 --- /dev/null +++ b/sugar.py @@ -0,0 +1,111 @@ +import nifty4 as ift +import numpy as np +from matplotlib import pyplot as plt +from multiprocessing import Pool +from separation_energy import StarbladeEnergy + + + +def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): + """ Setting up the StarbladeEnergy for the given data and parameters + Parameters + ---------- + data : array + The data in a numpy array + alpha : float + The slope parameter of the point source prior (default: 1.5). + q : float + The cutoff parameter of the point source prior (default: 1e-40). + cg_iterations : int + Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500). + """ + + s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1]) + h_space = s_space.get_default_codomain() + data = ift.Field(s_space,val=data) + FFT = ift.FFTOperator(h_space, target=s_space) + binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False) + p_space = ift.PowerSpace(h_space, binbounds=binbounds) + initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), binbounds=p_space.binbounds) + initial_spectrum /= (p_space.k_lengths+1.)**2 + initial_correlation = ift.create_power_operator(h_space, initial_spectrum) + initial_x = ift.Field(s_space, val=-1.) + alpha = ift.Field(s_space, val=alpha) + q = ift.Field(s_space, val=q) + ICI = ift.GradientNormController(iteration_limit=cg_iterations, + tol_abs_gradnorm=1e-5) + inverter = ift.ConjugateGradient(controller=ICI) + + parameters = dict(data=data, correlation=initial_correlation, + alpha=alpha, q=q, + inverter=inverter, FFT=FFT) + Starblade = StarbladeEnergy(position=initial_x, parameters=parameters) + return Starblade + +def starblade_iteration(starblade, iterations=3): + """ Performing one Newton minimization step + Parameters + ---------- + starblade : StarbladeEnergy + An instance of an Starblade Energy + iterations : int + The number of steps with the Newton scheme (default: 3). + """ + controller = ift.GradientNormController(name="Newton", tol_abs_gradnorm=1e-8, iteration_limit=iterations) + minimizer = ift.RelaxedNewton(controller=controller) + energy, convergence = minimizer(starblade) + new_position = energy.position + h_space = energy.correlation.domain[0] + FFT = energy.FFT + binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False) + new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds) + # new_power /= (new_power.domain[0].k_lengths+1.)**2 + + new_correlation = ift.create_power_operator(h_space, new_power) + new_parameters = energy.parameters + # new_parameters['correlation'] = new_correlation + NewStarblade = StarbladeEnergy(new_position, new_parameters) + return NewStarblade + +def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): + """ Builds a list of StarbladeEnergies for the given multi-channel dataset + Parameters + ---------- + data : array + The data in a numpy array of the multi-channel dataset with channel axis data[-1]. + alpha : float + The slope parameter of the point source prior (default: 1.5). + q : float + The cutoff parameter of the point source prior (default: 1e-40). + cg_iterations : int + Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500). + """ + MultiStarblade = [] + for i in range(data.shape[-1]): + starblade = build_starblade(data[...,i],alpha=alpha, q=q, cg_iterations=cg_iterations) + MultiStarblade.append(starblade) + return MultiStarblade + +def multi_starblade_iteration(MultiStarblade, multiprocessing = False): + """ Performing one Newton minimization step for all entries of the MultiStarblade list. + Parameters + ---------- + MultiStarblade : list of StarbladeEnergy + A list of instances of an Starblade Energy + iterations : int + The number of steps with the Newton scheme (default: 3). + """ + if multiprocessing: + NewStarblades = list(Pool(processes=3).map(starblade_iteration, + MultiStarblade)) + else: + NewStarblades = [] + for starblade in MultiStarblade: + NewStarblades.append(starblade_iteration(starblade)) + return NewStarblades + +if __name__ == '__main__': + pass + + + -- GitLab