Commit d3ef2730 authored by Jakob Knollmueller's avatar Jakob Knollmueller

finally all errors gone? back to KL, samples now reasonable, performs excellent

parent 8bdaa000
......@@ -4,132 +4,174 @@ from matplotlib import rc
rc('font',**{'family':'serif','serif':['Palatino']})
rc('text', usetex=True)
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
from astropy.io import fits
import starblade as sb
from scipy.stats import invgamma
np.random.seed(42)
if __name__ == '__main__':
s_space = ift.RGSpace([128,128])
h_space = s_space.get_default_codomain()
FFT = ift.FFTOperator(h_space)
p_spec = lambda k: (1./(1+k)**4)
mod_p_spec = lambda k: p_spec(k*1024)*1024**4
binbounds = ift.PowerSpace.useful_binbounds(h_space,logarithmic=True)#, nbin=100)
binbounds = ift.PowerSpace.useful_binbounds(h_space,logarithmic=False)#, nbin=100)
p_space = ift.PowerSpace(h_space,binbounds=binbounds)
k_lengths=p_space.k_lengths
# p_spec = ift.Field(p_space,val=p_spec(p_space.k_lengths))
S = ift.create_power_operator(h_space, power_spectrum=p_spec)
sh = S.draw_sample()
s = FFT(sh)
# k_lengths = sh.domain[0].k_lengths
u = ift.Field(s_space, val = -12.)
# u = 3*(ift.Field.from_random('normal',s_space)-1)
# u.val[20,20] = 3
# u.val[15,96] = 4
# u.val[128,128] = 5
# u.val[65,33] = 6
# u.val[156,119] = 4.5
# u.val[16,125] = 4.5
# u.val[156,51] = 2
# u.val[235,62] = 3.5
# u.val[54,125] = 1.3
x=np.random.randint(0,s_space.shape[0],(s_space.shape[0]/2,1))
y=np.random.randint(0,s_space.shape[0],(s_space.shape[0]/2,1))
brightness = np.random.uniform(-2,5,(s_space.shape[0]/2,1))
u.val[x,y] = brightness
brightness =ift.log(ift.Field(s_space,val=invgamma(0.5).rvs(s_space.shape))/1000)
# u.val[200] = 1
# u.val[300] = 3
# u.val[500] = 4
# u.val[700] = 5
# u.val[900] = 2
# u.val[154] = 0.5
# u.val[421] = 0.25
# u.val[652] = 1
# u.val[1002] = 2.5
# R = ift.FFTSmoothingOperator(s_space,sigma=0.001)
# R(u)
q=1e-3
alpha = 1.5
brightness =ift.log(ift.Field(s_space,val=invgamma.rvs(alpha-1.,scale=q,size=s_space.shape)))
u=brightness
d = ift.exp(s) +ift.exp(u)
data = d.val
energy1 = sb.build_starblade(data,1.25, newton_iterations=5, cg_iterations=500, q=1e-30)#, manual_power_spectrum= mod_p_spec)
energy2 = sb.build_starblade(data,1.5, newton_iterations=5, cg_iterations=500, q=1e-30)#, manual_power_spectrum= mod_p_spec)
energy3 = sb.build_starblade(data,1.75, newton_iterations=5, cg_iterations=500, q=1e-30)#, manual_power_spectrum=mod_p_spec)
hdu = fits.PrimaryHDU(data)
hdul = fits.HDUList([hdu])
hdul.writeto('mock_data.fits',overwrite=True)
sextracted64 = fits.open('check64.fits')
sextracted64 = sextracted64[0].data
sextracted64 = sextracted64.clip(0.001)
sextracted8 = fits.open('check8.fits')
sextracted8 = sextracted8[0].data
sextracted8 = sextracted8.clip(0.001)
energy1 = sb.build_starblade(data,1.0, newton_iterations=200, cg_iterations=3,
q=q)#, manual_power_spectrum= p_spec)
energy2 = sb.build_starblade(data,1.5, newton_iterations=200, cg_iterations=3,
q=q)#, manual_power_spectrum= p_spec)
energy3 = sb.build_starblade(data,3., newton_iterations=200, cg_iterations=3,
q=q)#, manual_power_spectrum= p_spec)
# ift.extra.check_value_gradient_consistency(energy1, tol=1e-3)
for i in range(10):
energy1 = sb.starblade_iteration(energy1, samples=0)
energy2 = sb.starblade_iteration(energy2, samples=0)
energy3 = sb.starblade_iteration(energy3, samples=0)
plt.imsave("2d_data.png",np.log(data))
energy1 = sb.starblade_iteration(energy1, samples=5)
energy2 = sb.starblade_iteration(energy2, samples=5)
energy3 = sb.starblade_iteration(energy3, samples=5)
print "error energy1:", np.sqrt(((ift.exp(energy1.s).val - ift.exp(s).val) ** 2).mean())
print "error energy2:", np.sqrt(((ift.exp(energy2.s).val - ift.exp(s).val) ** 2).mean())
print "error energy3:", np.sqrt(((ift.exp(energy3.s).val - ift.exp(s).val) ** 2).mean())
samples = []
n=30
for i in range(n):
samples.append(energy1.curvature.inverse.draw_sample())
m = 0
v = 0
s_s=0
v_s = 0
pos_tanh = ift.library.nonlinearities.PositiveTanh()
p=0
for sample in samples:
a_s = pos_tanh(sample)
m += a_s
v += a_s**2
s_s += ift.log(d*(1-a_s))
v_s += ift.log(d*(1-a_s))**2
p += ift.power_analyze(FFT.adjoint(s))
m /= n
v /= n
v -= m**2
s_s /= n
v_s /= n
v_s -= s_s**2
p /= n
lim_low = 1e-1
lim_high = 1e4
size = 15
plt.gray()
plt.figure()
plt.imshow(data,norm=LogNorm())
cbar = plt.colorbar()
cbar.set_label('intensity', size=size)
plt.axis('off')
plt.title('logarithmic mock data',size=size)
plt.savefig('2d_data.pdf')
# plt.figure(figsize=(8, 8))
# plt.plot(data, 'k-')
f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True)
f, (ax0, ax1,ax2,ax3,ax4) = plt.subplots(5, sharex=True, sharey=True,figsize=(8, 10))
plt.suptitle('diffuse components', size=size)
ax0.plot(ift.exp(energy1.s).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax0.plot(ift.exp(energy1.s).val[i], 'k-',alpha=(0.15/(energy1.s.val.shape[0])*i))
ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3)
ax0.set_ylabel(r'$\alpha = 1.0$', size=size)
ax0.set_ylim(lim_low ,lim_high)
ax0.set_yscale("log")
ax1.plot(ift.exp(energy2.s).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax1.plot(ift.exp(energy2.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(ift.exp(energy3.s).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax2.plot(ift.exp(energy3.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
ax2.set_ylabel(r'$\alpha = 3.0$', size=size)
for i in range(energy1.s.val.shape[0]):
ax3.plot(sextracted64[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax3.yaxis.set_label_position("right")
ax3.set_ylabel('sextractor'+ '\n'+ r'default', size=size)
for i in range(energy1.s.val.shape[0]):
ax4.plot((sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax4.yaxis.set_label_position("right")
ax4.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size)
plt.savefig('1d_diffuse.pdf')
plt.figure()
f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True)
f, (ax0, ax1,ax2,ax3,ax4) = plt.subplots(5, sharex=True, sharey=True,figsize=(8, 10))
plt.suptitle('point-like components', size=size)
ax0.plot(ift.exp(energy1.u).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax0.plot(ift.exp(energy1.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.25$', size=size)
ax0.set_ylim(1e-1,1e3)
ax0.set_ylabel(r'$\alpha = 1.0$', size=size)
ax0.set_ylim(lim_low ,lim_high)
ax0.set_yscale("log")
ax1.plot(ift.exp(energy2.u).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax1.plot(ift.exp(energy2.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax2.plot(ift.exp(energy3.u).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax2.plot(ift.exp(energy3.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 1.75$', size=size)
ax2.set_ylabel(r'$\alpha = 3.0$', size=size)
for i in range(energy1.s.val.shape[0]):
ax3.plot((data-sextracted64)[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax3.yaxis.set_label_position("right")
ax3.set_ylabel(r'sextractor'+ '\n'+ r'default', size=size)
for i in range(energy1.s.val.shape[0]):
ax4.plot((data-sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax4.yaxis.set_label_position("right")
ax4.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size)
ax0.set_yscale("log")
ax0.set_ylim(1e-1,1e3)
ax0.set_ylim(lim_low ,lim_high)
# plt.ylim(1e-0)
plt.savefig('1d_points.pdf')
plt.figure()
f, (ax0, ax1,ax2) = plt.subplots(3, sharex=True, sharey=True)
plt.suptitle('data and true components', size=size)
ax0.plot(data, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax0.plot(data[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax0.set_yscale("log")
ax0.set_ylim(1e-1,1e3)
ax0.set_ylim(lim_low ,lim_high)
ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'data', size=size)
ax1.plot(ift.exp(s).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax1.plot(ift.exp(s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'diffuse', size=size)
ax2.plot(ift.exp(u).val, 'k-',alpha=0.1)
for i in range(energy1.s.val.shape[0]):
ax2.plot(ift.exp(u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'point-like', size=size)
......@@ -137,25 +179,25 @@ if __name__ == '__main__':
plt.savefig('1d_data.pdf')
plt.figure()
binbounds = ift.PowerSpace.useful_binbounds(energy2.FFT.domain[0],logarithmic=True)#, nbin=100)
power1 = ift.power_analyze(energy2.FFT.inverse((energy1.s)),binbounds=binbounds)
power2 = ift.power_analyze(energy2.FFT.inverse((energy2.s)),binbounds=binbounds)
power3 = ift.power_analyze(energy2.FFT.inverse((energy3.s)),binbounds=binbounds)
binbounds = ift.PowerSpace.useful_binbounds(energy2.FFT.domain[0],logarithmic=False)#, nbin=100)
power1 = ift.power_analyze(energy2.FFT.adjoint((energy1.s)),binbounds=binbounds)
power2 = ift.power_analyze(energy2.FFT.adjoint((energy2.s)),binbounds=binbounds)
power3 = ift.power_analyze(energy2.FFT.adjoint((energy3.s)),binbounds=binbounds)
pp1 = energy1.power_spectrum
pp2 = energy2.power_spectrum
pp3 = energy3.power_spectrum
real_power = ift.power_analyze(sh)
power_d = ift.power_analyze(energy2.FFT.inverse(ift.log(energy2.d)),binbounds=binbounds)
power_d = ift.power_analyze(energy2.FFT.adjoint(ift.log(energy2.d)),binbounds=binbounds)
power_s = ift.power_analyze(energy2.FFT.adjoint(s),binbounds=binbounds)
# power_u = ift.power_analyze(energy2.FFT.inverse(ift.exp(energy2.u)),binbounds=binbounds)
k_lengths=power1.domain[0].k_lengths
plt.plot(k_lengths, p_spec(k_lengths*1024.)*1024**4, 'k-', label='theoretical')
plt.plot(k_lengths, power_d.val, 'k:', label='data')
plt.plot(k_lengths, power1.val, 'k-', label=(r'$\alpha = 1.25$'), alpha=0.6)
# plt.plot(k_lengths, p_spec(k_lengths*1024.)*1024, 'k-', label='theoretical')
plt.plot(k_lengths, power_d.val, 'k+-', label='data')
plt.plot(k_lengths, power1.val, 'k-', label=(r'$\alpha = 1.$'), alpha=0.15)
plt.plot(k_lengths, power2.val, 'k-', label=(r'$\alpha = 1.5$'),alpha=0.3)
plt.plot(k_lengths, power3.val, 'k-', label=(r'$\alpha = 1.75$'), alpha=0.15)
plt.plot(k_lengths, pp1.val, 'r-', label=(r'$\alpha = 1.25$'), alpha=0.6)
plt.plot(k_lengths, pp2.val, 'r-', label=(r'$\alpha = 1.5$'),alpha=0.3)
plt.plot(k_lengths, pp3.val, 'r-', label=(r'$\alpha = 1.75$'), alpha=0.15)
plt.plot(k_lengths, power3.val, 'k-', label=(r'$\alpha = 3.0$'), alpha=0.6)
plt.plot(k_lengths, power_s.val, 'k:', label=('signal'))
# plt.plot(k_lengths, power_u.val, 'k:',label='point-like')
plt.legend()
......@@ -166,4 +208,9 @@ if __name__ == '__main__':
plt.xlabel('harmonic mode',size=15)
plt.savefig('1d_power.pdf')
plt.close('all')
print "error energy1:", np.sqrt(((ift.exp(energy1.s).val-ift.exp(s).val)**2).mean())
print "error energy2:", np.sqrt(((ift.exp(energy2.s).val-ift.exp(s).val)**2).mean())
print "error energy3:", np.sqrt(((ift.exp(energy3.s).val-ift.exp(s).val)**2).mean())
print "error back_size8:", np.sqrt(((sextracted8-ift.exp(s).val)**2).mean())
print "error back_size64:", np.sqrt(((sextracted64-ift.exp(s).val)**2).mean())
......@@ -16,9 +16,10 @@
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, log, exp, DiagonalOperator, create_power_operator
from nifty4 import Energy, Field, log, exp, DiagonalOperator,\
create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler
from nifty4.library import WienerFilterCurvature
from nifty4.library.nonlinearities import PositiveTanh
from nifty4.library.nonlinearities import PositiveTanh, Tanh
class StarbladeEnergy(Energy):
......@@ -64,12 +65,16 @@ class StarbladeEnergy(Energy):
self.update_power = parameters['update_power']
self.newton_iterations = parameters['newton_iterations']
pos_tanh = PositiveTanh()
self.S = self.FFT * self.correlation * self.FFT.adjoint
tanh = Tanh()
self.S = SandwichOperator.make(self.FFT.adjoint, self.correlation)
# self.S = self.FFT * self.correlation * self.FFT.adjoint
self.a = pos_tanh(self.position)
self.a_p = pos_tanh.derivative(self.position)
self.a_pp = -tanh(position)*tanh.derivative(self.position)
self.u = log(self.d * self.a)
self.u_p = self.a_p/self.a
self.u_a = -log(self.a)
self.u_ap = - self.a_p/self.a
one_m_a = 1 - self.a
self.s = log(self.d * one_m_a)
self.s_p = - self.a_p / one_m_a
......@@ -80,24 +85,40 @@ class StarbladeEnergy(Energy):
@property
def value(self):
point = 0
diffuse = 0
det = 0
diffuse = 0.5 * self.s.vdot(self.S.inverse(self.s))
point = (self.alpha-1).vdot(self.u) + self.q.vdot(exp(-self.u))
det = self.s.integrate()
det = - self.s.sum()
det += - self.u_a.sum()
det += -log(self.a_p).sum()
det += 0.5 / self.var_x * self.position.vdot(self.position)
return diffuse + point + det
return diffuse + point + det
@property
def gradient(self):
point = 0
diffuse = 0
det = 0
diffuse = self.S.inverse(self.s) * self.s_p
point = (self.alpha - 1) * self.u_p - self.q * exp(-self.u) * self.u_p
det = self.position / self.var_x
det += self.s_p
return diffuse + point + det
det += - self.s_p
det += - self.u_ap
det += -1./self.a_p * self.a_pp
return +diffuse + point +det
@property
def curvature(self):
point = self.q * exp(-self.u) * self.u_p ** 2
R = self.FFT.inverse * self.s_p
N = self.correlation
S = DiagonalOperator(1/(point + 1/self.var_x))
return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
# R = self.FFT.inverse * self.s_p
# N = self.correlation
N_inv = DiagonalOperator(point + 1/self.var_x )#+ 2*self.a_p))
R = ScalingOperator(1., point.domain)
S_p = DiagonalOperator(self.s_p)
my_S_inv = SandwichOperator.make(self.FFT.adjoint.inverse.adjoint * S_p, self.correlation.inverse)
curv = InversionEnabler(N_inv + my_S_inv, self.inverter)
return curv
# return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
......@@ -16,7 +16,7 @@
#
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler
from nifty4 import Energy, Field, DiagonalOperator, InversionEnabler, full
from starblade_energy import StarbladeEnergy
class StarbladeKL(Energy):
......@@ -68,7 +68,7 @@ class StarbladeKL(Energy):
@property
def gradient(self):
gradient = Field.zeros(self.position.domain)
gradient = full(self.position.domain,0.)
for energy in self.energy_list:
gradient += energy.gradient
gradient /= len(self.energy_list)
......@@ -76,7 +76,7 @@ class StarbladeKL(Energy):
@property
def curvature(self):
curvature = DiagonalOperator(Field.zeros(self.position.domain))
curvature = DiagonalOperator(full(self.position.domain, 0.))
for energy in self.energy_list:
curvature += energy.curvature
curvature *= Field(self.position.domain,val=1./len(self.energy_list))
......
......@@ -43,16 +43,17 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio
If it is not specified, the algorithm will try to infer it via critical filtering.
"""
s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
s_space = ift.RGSpace(data.shape)#, distances=len(data.shape) * [1])
h_space = s_space.get_default_codomain()
data = ift.Field(s_space,val=data)
FFT = ift.FFTOperator(h_space, target=s_space)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = True)
binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
p_space = ift.PowerSpace(h_space, binbounds=binbounds)
if manual_power_spectrum is None:
initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)),
initial_spectrum = ift.power_analyze(FFT.adjoint(ift.log(data)),
binbounds=p_space.binbounds)
initial_spectrum /= (p_space.k_lengths+1.)**4
initial_spectrum /= 100*(p_space.k_lengths+1.)**4
# initial_spectrum = ift.Field(p_space,val=1e-3)
update_power = True
......@@ -89,19 +90,24 @@ def starblade_iteration(starblade, samples=3):
tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations)
minimizer = ift.RelaxedNewton(controller=controller)
# if len(sample_list)>0:
# energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
# else:
# minimizer = ift.VL_BFGS(controller=controller)
energy = starblade
sample_list = []
for i in range(samples):
sample = energy.curvature.inverse.draw_sample()
sample_list.append(sample)
if len(sample_list)>0:
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
else:
energy = starblade
energy, convergence = minimizer(energy)
energy = StarbladeEnergy(energy.position, parameters=energy.parameters)
sample_list = []
for i in range(samples):
sample = energy.curvature.inverse.draw_sample()
sample_list.append(sample)
if len(sample_list) == 0:
sample_list.append(energy.position)
# energy = StarbladeKL(energy.position, samples=sample_list, parameters=energy.parameters)
new_position = energy.position
new_parameters = energy.parameters
......@@ -171,11 +177,11 @@ def update_power(energy):
if isinstance(energy, StarbladeKL):
power = 0.
for en in energy.energy_list:
power = ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s),
power += ift.power_analyze(energy.parameters['FFT'].inverse(en.s),
binbounds=en.parameters['power_spectrum'].domain[0].binbounds)
# power /= len(energy.energy_list)
power /= len(energy.energy_list)
else:
power = ift.power_analyze(energy.FFT.inverse_times(energy.s),
power = ift.power_analyze(energy.FFT.inverse(energy.s),
binbounds=energy.parameters['power_spectrum'].domain[0].binbounds)
return power
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment