Commit b0e621a6 authored by Jakob Knollmueller's avatar Jakob Knollmueller

Merge branch 'master' into develop

parents fd9e15f6 1385b975
...@@ -25,18 +25,26 @@ import starblade as sb ...@@ -25,18 +25,26 @@ import starblade as sb
if __name__ == '__main__': if __name__ == '__main__':
#specifying location of the input file: #specifying location of the input file:
path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits' path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
data =[1].data path = 'data/frame-i-004874-3-0692.fits'
# data =[1].data
data =[0].data[1000:15000,1250:1750]
data -= data.min() - 0.001
# data = 1.-plt.imread('data/sdss.png').T[0]
# data =[1].data
data = data.clip(min=0.0001)
data = data.clip(min=0.001)
data = np.ndarray.astype(data, float) data = np.ndarray.astype(data, float)
vmin = np.log(data.min()+0.01) vmin = np.log(data.min()+0.2)
vmax = np.log(data.max()) vmax = np.log(data.max())
plt.imsave('data.png', np.log(data),vmin=vmin,vmax=vmax)
alpha = 1.25 alpha = 1.25
Starblade = sb.build_starblade(data, alpha=alpha) Starblade = sb.build_starblade(data, alpha=alpha)
for i in range(10): for i in range(10):
Starblade = sb.starblade_iteration(Starblade) Starblade = sb.starblade_iteration(Starblade, samples=i)
#plotting on logarithmic scale #plotting on logarithmic scale
plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax) plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax)
...@@ -48,5 +56,5 @@ if __name__ == '__main__': ...@@ -48,5 +56,5 @@ if __name__ == '__main__':
plt.yscale('log') plt.yscale('log')
plt.xscale('log') plt.xscale('log')
plt.ylabel('power') plt.ylabel('power')
plt.xscale('harmonic mode') plt.xlabel('harmonic mode')
plt.savefig('power_spectrum.png') plt.savefig('power_spectrum.png')
from .sugar import (build_starblade, starblade_iteration, from .sugar import (build_starblade, starblade_iteration,
build_multi_starblade, multi_starblade_iteration) build_multi_starblade, multi_starblade_iteration)
from .starblade_kl import StarbladeKL
from .starblade_energy import StarbladeEnergy
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2017-2018 Max-Planck-Society
# Author: Jakob Knollmueller
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler
from starblade_energy import StarbladeEnergy
class StarbladeKL(Energy):
"""The Kullback-Leibler divergence for the starblade problem.
position : Field
The current position of the separation.
samples : List
A list containing residual samples.
parameters : Dictionary
Dictionary containing all relevant quantities for the inference,
data : Field
The image data.
alpha : Field
Slope parameter of the point-source prior
q : Field
Cutoff parameter of the point-source prior
power_spectrum : callable or Field
An object that contains the power spectrum of the diffuse component
as a function of the harmonic mode.
FFT : FFTOperator
An operator performing the Fourier transform
inverter : ConjugateGradient
the minimization strategy to use for operator inversion
def __init__(self, position, samples, parameters):
super(StarbladeKL, self).__init__(position=position)
self.samples = samples
self.parameters = parameters
for sample in samples:
energy = StarbladeEnergy(position+sample,parameters)
def at(self, position):
return self.__class__(position, samples=self.samples, parameters=self.parameters)
def value(self):
value = 0.
for energy in self.energy_list:
value += energy.value
value /= len(self.energy_list)
return value
def gradient(self):
gradient = Field.zeros(self.position.domain)
for energy in self.energy_list:
gradient += energy.gradient
gradient /= len(self.energy_list)
return gradient
def curvature(self):
curvature = DiagonalOperator(Field.zeros(self.position.domain))
for energy in self.energy_list:
curvature += energy.curvature
curvature *= Field(self.position.domain,val=1./len(self.energy_list))
return InversionEnabler(curvature, self.parameters['inverter'])
...@@ -21,9 +21,9 @@ from multiprocessing import Pool ...@@ -21,9 +21,9 @@ from multiprocessing import Pool
import nifty4 as ift import nifty4 as ift
from .starblade_energy import StarbladeEnergy from .starblade_energy import StarbladeEnergy
from .starblade_kl import StarbladeKL
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iterations = 3,
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iterations = 3,
manual_power_spectrum = None): manual_power_spectrum = None):
""" Setting up the StarbladeEnergy for the given data and parameters """ Setting up the StarbladeEnergy for the given data and parameters
Parameters Parameters
...@@ -69,27 +69,37 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iteratio ...@@ -69,27 +69,37 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iteratio
inverter=inverter, FFT=FFT, inverter=inverter, FFT=FFT,
newton_iterations=newton_iterations, update_power=update_power) newton_iterations=newton_iterations, update_power=update_power)
Starblade = StarbladeEnergy(position=initial_x, parameters=parameters) Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
return Starblade return Starblade
def starblade_iteration(starblade):
def starblade_iteration(starblade, samples=3):
""" Performing one Newton minimization step """ Performing one Newton minimization step
Parameters Parameters
---------- ----------
starblade : StarbladeEnergy starblade : StarbladeEnergy
An instance of an Starblade Energy An instance of an Starblade Energy
samples : int
Number of samples drawn in order to estimate the KL. If zero the MAP is calculated (default: 3).
""" """
controller = ift.GradientNormController(name="Newton", controller = ift.GradientNormController(name="Newton",
tol_abs_gradnorm=1e-8, tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations) iteration_limit=starblade.newton_iterations)
minimizer = ift.RelaxedNewton(controller=controller) minimizer = ift.RelaxedNewton(controller=controller)
energy, convergence = minimizer(starblade) sample_list = []
for i in range(samples):
sample = starblade.curvature.inverse.draw_sample()
if len(sample_list)>0:
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
energy = starblade
energy, convergence = minimizer(energy)
new_position = energy.position new_position = energy.position
new_parameters = energy.parameters new_parameters = energy.parameters
if energy.update_power: if energy.parameters['update_power']:
h_space = energy.correlation.domain[0] new_power = update_power(energy)
FFT = energy.FFT
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
# new_power /= (new_power.domain[0].k_lengths+1.)**2 # new_power /= (new_power.domain[0].k_lengths+1.)**2
new_parameters['power_spectrum'] = new_power new_parameters['power_spectrum'] = new_power
...@@ -143,6 +153,25 @@ def multi_starblade_iteration(MultiStarblade, processes = 1): ...@@ -143,6 +153,25 @@ def multi_starblade_iteration(MultiStarblade, processes = 1):
NewStarblades.append(starblade_iteration(starblade)) NewStarblades.append(starblade_iteration(starblade))
return NewStarblades return NewStarblades
def update_power(energy):
""" Calculates a new estimate of the power spectrum given a StarbladeEnergy or StarbladeKL.
For Energy the MAP estimate of the power spectrum is calculated and for KL the variational estimate.
energy : StarbladeEnergy or StarbladeKL
An instance of an StarbladeEnergy or StarbladeKL
if isinstance(energy, StarbladeKL):
power = 0.
for en in energy.energy_list:
power += ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s),
power /= len(energy.energy_list)
power = ift.power_analyze(energy.FFT.inverse_times(energy.s),
return power
if __name__ == '__main__': if __name__ == '__main__':
pass pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment