Commit f4fb46b8 authored by Jakob Knollmueller's avatar Jakob Knollmueller

cleanup

parent 1ec7df35
......@@ -24,7 +24,7 @@ import starblade as sb
if __name__ == '__main__':
#specifying location of the input file:
path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
data = fits.open(path)[1].data
data = data.clip(min=0.001)
......
......@@ -19,19 +19,22 @@
import numpy as np
from matplotlib import pyplot as plt
import starblade as sb
import nifty4 as ift
if __name__ == '__main__':
# data = plt.imread('10Keso1242a.tif')
data = plt.imread('data/eso1242a.jpg')
data = plt.imread('data/galaxy.jpg')
data = data.astype(float)
data = data.clip(0.0001)
alpha = 1.25
MultiStarblade = sb.build_multi_starblade(data, alpha)
alpha = 1.3
MultiStarblade = sb.build_multi_starblade(data, alpha, newton_iterations=1)
# power = MultiStarblade[0].power_spectrum / np.arange(len(MultiStarblade[0].power_spectrum.val))**0
# MultiStarblade = sb.build_multi_starblade(data,alpha,newton_iterations=1, manual_power_spectrum=power)
for i in range(10):
MultiStarblade = sb.multi_starblade_iteration(MultiStarblade, multiprocessing=True)
for i in range(50):
MultiStarblade = sb.multi_starblade_iteration(MultiStarblade, processes=1)
#plotting a three channel RGB image in each iteration
diffuse = np.empty_like(data)
......@@ -40,5 +43,13 @@ if __name__ == '__main__':
diffuse[...,i] = np.exp(MultiStarblade[i].s.val)
point[...,i] = np.exp(MultiStarblade[i].u.val)
plt.imsave('rgb_diffuse.jpg',diffuse/255.)
plt.imsave('rgb_point.jpg',point/255.)
plt.figure()
plt.plot(MultiStarblade[0].power_spectrum.val)
plt.yscale('log')
plt.xscale('log')
plt.savefig('power.jpg')
plt.close()
......@@ -16,7 +16,7 @@
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4 import Energy, Field, log, exp, DiagonalOperator, create_power_operator
from nifty4.library import WienerFilterCurvature
from nifty4.library.nonlinearities import PositiveTanh
......@@ -38,9 +38,9 @@ class StarbladeEnergy(Energy):
Slope parameter of the point-source prior
q : Field
Cutoff parameter of the point-source prior
correlation : Field
A field in the Fourier space which encodes the diagonal of the prior
correlation structure of the diffuse component
power_spectrum : callable or Field
An object that contains the power spectrum of the diffuse component
as a function of the harmonic mode.
FFT : FFTOperator
An operator performing the Fourier transform
inverter : ConjugateGradient
......@@ -57,9 +57,12 @@ class StarbladeEnergy(Energy):
self.inverter = parameters['inverter']
self.d = parameters['data']
self.FFT = parameters['FFT']
self.correlation = parameters['correlation']
self.power_spectrum = parameters['power_spectrum']
self.correlation = create_power_operator(self.FFT.domain, self.power_spectrum)
self.alpha = parameters['alpha']
self.q = parameters['q']
self.update_power = parameters['update_power']
self.newton_iterations = parameters['newton_iterations']
pos_tanh = PositiveTanh()
self.S = self.FFT * self.correlation * self.FFT.adjoint
self.a = pos_tanh(self.position)
......
......@@ -23,7 +23,8 @@ import nifty4 as ift
from .starblade_energy import StarbladeEnergy
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iterations = 3,
manual_power_spectrum = None):
""" Setting up the StarbladeEnergy for the given data and parameters
Parameters
----------
......@@ -35,6 +36,11 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
The cutoff parameter of the point source prior (default: 1e-40).
cg_iterations : int
Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
newton_iterations : int
Number of consecutive Newton steps within one algorithmic step.(default: 3)
manual_power_spectrum : None, Field or callable
Option to set a manual power spectrum which is kept constant during the separation.
If it is not specified, the algorithm will try to infer it via critical filtering.
"""
s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
......@@ -43,9 +49,14 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
FFT = ift.FFTOperator(h_space, target=s_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = ift.PowerSpace(h_space, binbounds=binbounds)
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), binbounds=p_space.binbounds)
initial_spectrum /= (p_space.k_lengths+1.)**2
initial_correlation = ift.create_power_operator(h_space, initial_spectrum)
if manual_power_spectrum is None:
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)),
binbounds=p_space.binbounds)
initial_spectrum /= (p_space.k_lengths+1.)**2
update_power = True
else:
initial_spectrum = manual_power_spectrum
update_power = False
initial_x = ift.Field(s_space, val=-1.)
alpha = ift.Field(s_space, val=alpha)
q = ift.Field(s_space, val=q)
......@@ -53,38 +64,40 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(controller=ICI)
parameters = dict(data=data, correlation=initial_correlation,
parameters = dict(data=data, power_spectrum=initial_spectrum,
alpha=alpha, q=q,
inverter=inverter, FFT=FFT)
inverter=inverter, FFT=FFT,
newton_iterations=newton_iterations, update_power=update_power)
Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
return Starblade
def starblade_iteration(starblade, iterations=3):
def starblade_iteration(starblade):
""" Performing one Newton minimization step
Parameters
----------
starblade : StarbladeEnergy
An instance of an Starblade Energy
iterations : int
The number of steps with the Newton scheme (default: 3).
"""
controller = ift.GradientNormController(name="Newton", tol_abs_gradnorm=1e-8, iteration_limit=iterations)
controller = ift.GradientNormController(name="Newton",
tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations)
minimizer = ift.RelaxedNewton(controller=controller)
energy, convergence = minimizer(starblade)
new_position = energy.position
h_space = energy.correlation.domain[0]
FFT = energy.FFT
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
# new_power /= (new_power.domain[0].k_lengths+1.)**2
new_correlation = ift.create_power_operator(h_space, new_power)
new_parameters = energy.parameters
# new_parameters['correlation'] = new_correlation
if energy.update_power:
h_space = energy.correlation.domain[0]
FFT = energy.FFT
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
# new_power /= (new_power.domain[0].k_lengths+1.)**2
new_parameters['power_spectrum'] = new_power
NewStarblade = StarbladeEnergy(new_position, new_parameters)
return NewStarblade
def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500,
newton_iterations=3, manual_power_spectrum = None):
""" Builds a list of StarbladeEnergies for the given multi-channel dataset
Parameters
----------
......@@ -96,24 +109,33 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
The cutoff parameter of the point source prior (default: 1e-40).
cg_iterations : int
Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
newton_iterations : int
Number of consecutive Newton steps within one algorithmic step.(default: 3)
manual_power_spectrum : None, Field or callable
Option to set a manual power spectrum which is kept constant during the separation.
If it is not specified, the algorithm will try to infer it via critical filtering.
"""
MultiStarblade = []
for i in range(data.shape[-1]):
starblade = build_starblade(data[...,i],alpha=alpha, q=q, cg_iterations=cg_iterations)
starblade = build_starblade(data[...,i],alpha=alpha, q=q,
cg_iterations=cg_iterations,
newton_iterations=newton_iterations,
manual_power_spectrum = manual_power_spectrum)
MultiStarblade.append(starblade)
return MultiStarblade
def multi_starblade_iteration(MultiStarblade, multiprocessing = False):
def multi_starblade_iteration(MultiStarblade, processes = 1):
""" Performing one Newton minimization step for all entries of the MultiStarblade list.
Parameters
----------
MultiStarblade : list of StarbladeEnergy
A list of instances of an Starblade Energy
iterations : int
The number of steps with the Newton scheme (default: 3).
processes : int
Each channel can be computed independently.
This number specifies how many processes are set up.(default: 1)
"""
if multiprocessing:
NewStarblades = list(Pool(processes=3).map(starblade_iteration,
if processes>1:
NewStarblades = list(Pool(processes=processes).map(starblade_iteration,
MultiStarblade))
else:
NewStarblades = []
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment