Planned maintenance on Wednesday, 2021-01-20, 17:00-18:00. Expect some interruptions during that time

Commit f4fb46b8 authored by Jakob Knollmueller's avatar Jakob Knollmueller

cleanup

parent 1ec7df35
...@@ -24,7 +24,7 @@ import starblade as sb ...@@ -24,7 +24,7 @@ import starblade as sb
if __name__ == '__main__': if __name__ == '__main__':
#specifying location of the input file: #specifying location of the input file:
path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits' path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
data = fits.open(path)[1].data data = fits.open(path)[1].data
data = data.clip(min=0.001) data = data.clip(min=0.001)
......
...@@ -19,19 +19,22 @@ ...@@ -19,19 +19,22 @@
import numpy as np import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import starblade as sb import starblade as sb
import nifty4 as ift
if __name__ == '__main__': if __name__ == '__main__':
# data = plt.imread('10Keso1242a.tif') # data = plt.imread('10Keso1242a.tif')
data = plt.imread('data/eso1242a.jpg') data = plt.imread('data/galaxy.jpg')
data = data.astype(float) data = data.astype(float)
data = data.clip(0.0001) data = data.clip(0.0001)
alpha = 1.25 alpha = 1.3
MultiStarblade = sb.build_multi_starblade(data, alpha) MultiStarblade = sb.build_multi_starblade(data, alpha, newton_iterations=1)
# power = MultiStarblade[0].power_spectrum / np.arange(len(MultiStarblade[0].power_spectrum.val))**0
# MultiStarblade = sb.build_multi_starblade(data,alpha,newton_iterations=1, manual_power_spectrum=power)
for i in range(10): for i in range(50):
MultiStarblade = sb.multi_starblade_iteration(MultiStarblade, multiprocessing=True) MultiStarblade = sb.multi_starblade_iteration(MultiStarblade, processes=1)
#plotting a three channel RGB image in each iteration #plotting a three channel RGB image in each iteration
diffuse = np.empty_like(data) diffuse = np.empty_like(data)
...@@ -40,5 +43,13 @@ if __name__ == '__main__': ...@@ -40,5 +43,13 @@ if __name__ == '__main__':
diffuse[...,i] = np.exp(MultiStarblade[i].s.val) diffuse[...,i] = np.exp(MultiStarblade[i].s.val)
point[...,i] = np.exp(MultiStarblade[i].u.val) point[...,i] = np.exp(MultiStarblade[i].u.val)
plt.imsave('rgb_diffuse.jpg',diffuse/255.) plt.imsave('rgb_diffuse.jpg',diffuse/255.)
plt.imsave('rgb_point.jpg',point/255.) plt.imsave('rgb_point.jpg',point/255.)
plt.figure()
plt.plot(MultiStarblade[0].power_spectrum.val)
plt.yscale('log')
plt.xscale('log')
plt.savefig('power.jpg')
plt.close()
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# #
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik # Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, log, exp, DiagonalOperator from nifty4 import Energy, Field, log, exp, DiagonalOperator, create_power_operator
from nifty4.library import WienerFilterCurvature from nifty4.library import WienerFilterCurvature
from nifty4.library.nonlinearities import PositiveTanh from nifty4.library.nonlinearities import PositiveTanh
...@@ -38,9 +38,9 @@ class StarbladeEnergy(Energy): ...@@ -38,9 +38,9 @@ class StarbladeEnergy(Energy):
Slope parameter of the point-source prior Slope parameter of the point-source prior
q : Field q : Field
Cutoff parameter of the point-source prior Cutoff parameter of the point-source prior
correlation : Field power_spectrum : callable or Field
A field in the Fourier space which encodes the diagonal of the prior An object that contains the power spectrum of the diffuse component
correlation structure of the diffuse component as a function of the harmonic mode.
FFT : FFTOperator FFT : FFTOperator
An operator performing the Fourier transform An operator performing the Fourier transform
inverter : ConjugateGradient inverter : ConjugateGradient
...@@ -57,9 +57,12 @@ class StarbladeEnergy(Energy): ...@@ -57,9 +57,12 @@ class StarbladeEnergy(Energy):
self.inverter = parameters['inverter'] self.inverter = parameters['inverter']
self.d = parameters['data'] self.d = parameters['data']
self.FFT = parameters['FFT'] self.FFT = parameters['FFT']
self.correlation = parameters['correlation'] self.power_spectrum = parameters['power_spectrum']
self.correlation = create_power_operator(self.FFT.domain, self.power_spectrum)
self.alpha = parameters['alpha'] self.alpha = parameters['alpha']
self.q = parameters['q'] self.q = parameters['q']
self.update_power = parameters['update_power']
self.newton_iterations = parameters['newton_iterations']
pos_tanh = PositiveTanh() pos_tanh = PositiveTanh()
self.S = self.FFT * self.correlation * self.FFT.adjoint self.S = self.FFT * self.correlation * self.FFT.adjoint
self.a = pos_tanh(self.position) self.a = pos_tanh(self.position)
......
...@@ -23,7 +23,8 @@ import nifty4 as ift ...@@ -23,7 +23,8 @@ import nifty4 as ift
from .starblade_energy import StarbladeEnergy from .starblade_energy import StarbladeEnergy
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iterations = 3,
manual_power_spectrum = None):
""" Setting up the StarbladeEnergy for the given data and parameters """ Setting up the StarbladeEnergy for the given data and parameters
Parameters Parameters
---------- ----------
...@@ -35,6 +36,11 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): ...@@ -35,6 +36,11 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
The cutoff parameter of the point source prior (default: 1e-40). The cutoff parameter of the point source prior (default: 1e-40).
cg_iterations : int cg_iterations : int
Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500). Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
newton_iterations : int
Number of consecutive Newton steps within one algorithmic step.(default: 3)
manual_power_spectrum : None, Field or callable
Option to set a manual power spectrum which is kept constant during the separation.
If it is not specified, the algorithm will try to infer it via critical filtering.
""" """
s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1]) s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
...@@ -43,9 +49,14 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): ...@@ -43,9 +49,14 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
FFT = ift.FFTOperator(h_space, target=s_space) FFT = ift.FFTOperator(h_space, target=s_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False) binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = ift.PowerSpace(h_space, binbounds=binbounds) p_space = ift.PowerSpace(h_space, binbounds=binbounds)
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), binbounds=p_space.binbounds) if manual_power_spectrum is None:
initial_spectrum /= (p_space.k_lengths+1.)**2 initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)),
initial_correlation = ift.create_power_operator(h_space, initial_spectrum) binbounds=p_space.binbounds)
initial_spectrum /= (p_space.k_lengths+1.)**2
update_power = True
else:
initial_spectrum = manual_power_spectrum
update_power = False
initial_x = ift.Field(s_space, val=-1.) initial_x = ift.Field(s_space, val=-1.)
alpha = ift.Field(s_space, val=alpha) alpha = ift.Field(s_space, val=alpha)
q = ift.Field(s_space, val=q) q = ift.Field(s_space, val=q)
...@@ -53,38 +64,40 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): ...@@ -53,38 +64,40 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
tol_abs_gradnorm=1e-5) tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(controller=ICI) inverter = ift.ConjugateGradient(controller=ICI)
parameters = dict(data=data, correlation=initial_correlation, parameters = dict(data=data, power_spectrum=initial_spectrum,
alpha=alpha, q=q, alpha=alpha, q=q,
inverter=inverter, FFT=FFT) inverter=inverter, FFT=FFT,
newton_iterations=newton_iterations, update_power=update_power)
Starblade = StarbladeEnergy(position=initial_x, parameters=parameters) Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
return Starblade return Starblade
def starblade_iteration(starblade, iterations=3): def starblade_iteration(starblade):
""" Performing one Newton minimization step """ Performing one Newton minimization step
Parameters Parameters
---------- ----------
starblade : StarbladeEnergy starblade : StarbladeEnergy
An instance of an Starblade Energy An instance of an Starblade Energy
iterations : int
The number of steps with the Newton scheme (default: 3).
""" """
controller = ift.GradientNormController(name="Newton", tol_abs_gradnorm=1e-8, iteration_limit=iterations) controller = ift.GradientNormController(name="Newton",
tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations)
minimizer = ift.RelaxedNewton(controller=controller) minimizer = ift.RelaxedNewton(controller=controller)
energy, convergence = minimizer(starblade) energy, convergence = minimizer(starblade)
new_position = energy.position new_position = energy.position
h_space = energy.correlation.domain[0]
FFT = energy.FFT
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
# new_power /= (new_power.domain[0].k_lengths+1.)**2
new_correlation = ift.create_power_operator(h_space, new_power)
new_parameters = energy.parameters new_parameters = energy.parameters
# new_parameters['correlation'] = new_correlation if energy.update_power:
h_space = energy.correlation.domain[0]
FFT = energy.FFT
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
# new_power /= (new_power.domain[0].k_lengths+1.)**2
new_parameters['power_spectrum'] = new_power
NewStarblade = StarbladeEnergy(new_position, new_parameters) NewStarblade = StarbladeEnergy(new_position, new_parameters)
return NewStarblade return NewStarblade
def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500,
newton_iterations=3, manual_power_spectrum = None):
""" Builds a list of StarbladeEnergies for the given multi-channel dataset """ Builds a list of StarbladeEnergies for the given multi-channel dataset
Parameters Parameters
---------- ----------
...@@ -96,24 +109,33 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500): ...@@ -96,24 +109,33 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
The cutoff parameter of the point source prior (default: 1e-40). The cutoff parameter of the point source prior (default: 1e-40).
cg_iterations : int cg_iterations : int
Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500). Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
newton_iterations : int
Number of consecutive Newton steps within one algorithmic step.(default: 3)
manual_power_spectrum : None, Field or callable
Option to set a manual power spectrum which is kept constant during the separation.
If it is not specified, the algorithm will try to infer it via critical filtering.
""" """
MultiStarblade = [] MultiStarblade = []
for i in range(data.shape[-1]): for i in range(data.shape[-1]):
starblade = build_starblade(data[...,i],alpha=alpha, q=q, cg_iterations=cg_iterations) starblade = build_starblade(data[...,i],alpha=alpha, q=q,
cg_iterations=cg_iterations,
newton_iterations=newton_iterations,
manual_power_spectrum = manual_power_spectrum)
MultiStarblade.append(starblade) MultiStarblade.append(starblade)
return MultiStarblade return MultiStarblade
def multi_starblade_iteration(MultiStarblade, multiprocessing = False): def multi_starblade_iteration(MultiStarblade, processes = 1):
""" Performing one Newton minimization step for all entries of the MultiStarblade list. """ Performing one Newton minimization step for all entries of the MultiStarblade list.
Parameters Parameters
---------- ----------
MultiStarblade : list of StarbladeEnergy MultiStarblade : list of StarbladeEnergy
A list of instances of an Starblade Energy A list of instances of an Starblade Energy
iterations : int processes : int
The number of steps with the Newton scheme (default: 3). Each channel can be computed independently.
This number specifies how many processes are set up.(default: 1)
""" """
if multiprocessing: if processes>1:
NewStarblades = list(Pool(processes=3).map(starblade_iteration, NewStarblades = list(Pool(processes=processes).map(starblade_iteration,
MultiStarblade)) MultiStarblade))
else: else:
NewStarblades = [] NewStarblades = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment