Commit e091563a authored by Jakob Knollmueller's avatar Jakob Knollmueller

enabling KL

parent c2e4c7b9
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
import numpy as np
from astropy.io import fits
from matplotlib import pyplot as plt
from multiprocessing import Pool
import nifty4 as ift
from nifty4.library.nonlinearities import PositiveTanh
import starblade as sb
from starblade.starblade_energy import StarbladeEnergy
from starblade.starblade_kl import StarbladeKL
def power_update(KL_energy):
power = 0.
for energy in KL_energy.energy_list:
power += ift.power_analyze(FFT.inverse_times(energy.s),
binbounds=p_space.binbounds)
power /= len(KL_energy.energy_list)
return power
if __name__ == '__main__':
#specifying location of the input file:
path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
path = 'data/frame-u-006174-2-0094.fits'
# path = 'data/frame-g-002821-6-0141.fits'
path = 'data/frame-g-007812-6-0100.fits'
path = 'data/frame-i-004874-3-0692.fits'
# data = fits.open(path)[1].data
data = fits.open(path)[0].data#[1000:,1250:]
data -= data.min() - 0.001
# data = np.exp(2*(1.-plt.imread('data/sdss.png').T[0]))
# data = (plt.imread('data/m51_3.jpg').T[0])
# data = (plt.imread('data/12_FBP.png').T[0])
#
# data = data.clip(min=0.001)
data = np.ndarray.astype(data, float)
vmin = np.log(data.min()+0.01)
vmax = np.log(data.max())
plt.imsave('data.png', np.log(data))
postanh=PositiveTanh()
alpha = 1.5
s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
h_space = s_space.get_default_codomain()
data = ift.Field(s_space,val=data)
FFT = ift.FFTOperator(h_space, target=s_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = ift.PowerSpace(h_space, binbounds=binbounds)
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)),
binbounds=p_space.binbounds)
initial_spectrum /= (p_space.k_lengths+1.)**4
update_power = True
initial_x = ift.Field(s_space, val=-1.)
alpha = ift.Field(s_space, val=alpha)
q = ift.Field(s_space, val=1e-30)
ICI = ift.GradientNormController(iteration_limit=100,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
parameters = dict(data=data, power_spectrum=initial_spectrum,
alpha=alpha, q=q,
inverter=inverter, FFT=FFT,
newton_iterations=5, update_power=update_power)
current_x = initial_x
for i in range(10):
Starblade = StarbladeEnergy(position=current_x, parameters=parameters)
samples = []
for i in range(3):
sample = Starblade.curvature.inverse.draw_sample()
samples.append(sample)
problem = StarbladeKL(current_x, samples,parameters)
controller = ift.GradientNormController(name="Newton",
tol_abs_gradnorm=1e-5,
iteration_limit=5)
minimizer = ift.RelaxedNewton(controller=controller)
problem, convergence = minimizer(problem)
current_x = problem.position
parameters['power_spectrum'] = power_update(problem)
Starblade = StarbladeEnergy(position=current_x, parameters=parameters)
# Starblade = sb.build_starblade(data, alpha=alpha)
# for i in range(10):
# Starblade = sb.starblade_iteration(Starblade)
#
# #plotting on logarithmic scale
plt.imsave('diffuse_component.png', (Starblade.s).val,vmin=vmin, vmax=vmax)
plt.imsave('pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax)
Starblade = StarbladeEnergy(position=current_x, parameters=parameters)
var = 0.
mean = 0
samps = 30
for i in range(samps):
sam = postanh(Starblade.position+Starblade.curvature.inverse.draw_sample())
mean += sam
var += sam**2
var /= samps
mean /= samps
var -= mean**2
mask = ift.sqrt(var) < 0.01 +0.
plt.imsave('masked_points.png', mask.val * Starblade.u.val, vmin=vmin, vmax=vmax)
plt.imsave('masked_diffuse.png', mask.val * Starblade.s.val)
plt.imsave('std.png', np.log(np.sqrt(var.val)*data.val), vmin=-3.3)
# plt.figure()
# k_lenghts = Starblade.power_spectrum.domain[0].k_lengths
# plt.plot(k_lenghts, Starblade.power_spectrum.val)
# plt.title('power spectrum')
# plt.yscale('log')
# plt.xscale('log')
# plt.ylabel('power')
# plt.xscale('harmonic mode')
# plt.savefig('power_spectrum.png')
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
import numpy as np
from astropy.io import fits
from matplotlib import pyplot as plt
from scipy.ndimage.filters import median_filter
import starblade as sb
if __name__ == '__main__':
#specifying location of the input file:
# path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
# data = fits.open(path)[1].data
path = 'data/frame-i-004874-3-0692.fits'
path ='data/check.fits'
# data = fits.open(path)[1].data
data = fits.open(path)[0].data[1000:,1250:]
data -= data.min() - 0.001
data = data.clip(min=0.001)
data_true = data.copy()
data = np.ndarray.astype(data, float)
vmin = np.log(data.min()+0.01)
vmax = np.log(data.max())
local_size = 4
for i in range(5):
for i in range(data.shape[0]/local_size):
for j in range(data.shape[1]/local_size):
local_data = data[i*local_size:(1+i)*local_size,j*local_size:(1+j)*local_size]
local_data_median = np.median(local_data)
local_data_var = local_data.var()
local_data = local_data.clip(min=local_data_median - 3*np.sqrt(local_data_var),
max=local_data_median + 3*np.sqrt(local_data_var))
data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] = local_data
background = np.empty_like(data)
crowded = np.zeros_like(data)
for i in range(data.shape[0] / local_size):
for j in range(data.shape[1] / local_size):
local_true_data = data_true[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size]
local_data = data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size]
local_true_var = local_true_data.var()
local_var = local_data.var()
if 0.8 * np.sqrt(local_true_var) > np.sqrt(local_var):
background[i * local_size:(1 + i) * local_size,
j * local_size:(1 + j) * local_size] = 2.5*np.median(local_data)-1.5*local_data.mean()
crowded[i * local_size:(1 + i) * local_size,
j * local_size:(1 + j) * local_size] = 1.
else:
background[i * local_size:(1 + i) * local_size,
j * local_size:(1 + j) * local_size] = local_data.mean()
background = median_filter(background, size=(local_size,local_size))
# alpha = 1.25
# Starblade = sb.build_starblade(data, alpha=alpha)
# for i in range(10):
# Starblade = sb.starblade_iteration(Starblade)
#
# plotting on logarithmic scale
# background += background.min()
plt.gray()
plt.imsave('diffuse_component.png', np.log(background))#, vmin=vmin, vmax=vmax)
plt.imsave('pointlike_component.png', (data_true - background), vmin=vmin, vmax=vmax)
plt.imsave('crowded.png',crowded)
# plt.figure()
# k_lenghts = Starblade.power_spectrum.domain[0].k_lengths
# plt.plot(k_lenghts, Starblade.power_spectrum.val)
# plt.title('power spectrum')
# plt.yscale('log')
# plt.xscale('log')
# plt.ylabel('power')
# plt.xscale('harmonic mode')
# plt.savefig('power_spectrum.png')
......@@ -25,18 +25,26 @@ import starblade as sb
if __name__ == '__main__':
#specifying location of the input file:
path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
data = fits.open(path)[1].data
path = 'data/frame-i-004874-3-0692.fits'
# data = fits.open(path)[1].data
data = fits.open(path)[0].data[1000:15000,1250:1750]
data -= data.min() - 0.001
# data = 1.-plt.imread('data/sdss.png').T[0]
# data = fits.open(path)[1].data
data = data.clip(min=0.0001)
data = data.clip(min=0.001)
data = np.ndarray.astype(data, float)
vmin = np.log(data.min()+0.01)
vmin = np.log(data.min()+0.2)
vmax = np.log(data.max())
plt.imsave('data.png', np.log(data),vmin=vmin,vmax=vmax)
alpha = 1.25
Starblade = sb.build_starblade(data, alpha=alpha)
for i in range(10):
Starblade = sb.starblade_iteration(Starblade)
Starblade = sb.starblade_iteration(Starblade, samples=i)
#plotting on logarithmic scale
plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax)
......@@ -48,5 +56,5 @@ if __name__ == '__main__':
plt.yscale('log')
plt.xscale('log')
plt.ylabel('power')
plt.xscale('harmonic mode')
plt.xlabel('harmonic mode')
plt.savefig('power_spectrum.png')
from .sugar import (build_starblade, starblade_iteration,
build_multi_starblade, multi_starblade_iteration)
from .starblade_kl import StarbladeKL
from .starblade_energy import StarbladeEnergy
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler
from starblade_energy import StarbladeEnergy
class StarbladeKL(Energy):
def __init__(self, position, samples, parameters):
super(StarbladeKL, self).__init__(position=position)
self.samples = samples
self.parameters = parameters
self.energy_list=[]
for sample in samples:
energy = StarbladeEnergy(position+sample,parameters)
self.energy_list.append(energy)
def at(self, position):
return self.__class__(position, samples=self.samples, parameters=self.parameters)
@property
def value(self):
value = 0.
for energy in self.energy_list:
value += energy.value
value /= len(self.energy_list)
return value
@property
def gradient(self):
gradient = Field.zeros(self.position.domain)
for energy in self.energy_list:
gradient += energy.gradient
gradient /= len(self.energy_list)
return gradient
@property
def curvature(self):
curvature = DiagonalOperator(Field.zeros(self.position.domain))
for energy in self.energy_list:
curvature += energy.curvature
curvature *= Field(self.position.domain,val=1./len(self.energy_list))
return InversionEnabler(curvature, self.parameters['inverter'])
......@@ -21,9 +21,9 @@ from multiprocessing import Pool
import nifty4 as ift
from .starblade_energy import StarbladeEnergy
from .starblade_kl import StarbladeKL
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iterations = 3,
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iterations = 3,
manual_power_spectrum = None):
""" Setting up the StarbladeEnergy for the given data and parameters
Parameters
......@@ -69,9 +69,12 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iteratio
inverter=inverter, FFT=FFT,
newton_iterations=newton_iterations, update_power=update_power)
Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
return Starblade
def starblade_iteration(starblade):
def starblade_iteration(starblade, samples=3):
""" Performing one Newton minimization step
Parameters
----------
......@@ -82,14 +85,19 @@ def starblade_iteration(starblade):
tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations)
minimizer = ift.RelaxedNewton(controller=controller)
energy, convergence = minimizer(starblade)
sample_list = []
for i in range(samples):
sample = starblade.curvature.inverse.draw_sample()
sample_list.append(sample)
if len(sample_list)>0:
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
else:
energy = starblade
energy, convergence = minimizer(energy)
new_position = energy.position
new_parameters = energy.parameters
if energy.update_power:
h_space = energy.correlation.domain[0]
FFT = energy.FFT
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
if energy.parameters['update_power']:
new_power = update_power(energy)
# new_power /= (new_power.domain[0].k_lengths+1.)**2
new_parameters['power_spectrum'] = new_power
......@@ -143,6 +151,18 @@ def multi_starblade_iteration(MultiStarblade, processes = 1):
NewStarblades.append(starblade_iteration(starblade))
return NewStarblades
def update_power(energy):
if isinstance(energy, StarbladeKL):
power = 0.
for en in energy.energy_list:
power += ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s),
binbounds=en.parameters['power_spectrum'].domain[0].binbounds)
power /= len(energy.energy_list)
else:
power = ift.power_analyze(energy.FFT.inverse_times(energy.s),
binbounds=energy.parameters['power_spectrum'].domain[0].binbounds)
return power
if __name__ == '__main__':
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment