Commit 8c9c6556 authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

paper plots

parent b6ed499e
......@@ -64,7 +64,7 @@ if __name__ == '__main__':
plt.imsave('point_sex.png', (data_unmod - sex), vmax = lin_max,vmin =lin_min)
alpha = 1.4
Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=10)
Starblade = sb.build_starblade(data, alpha=alpha, cg_steps=100, newton_steps=10)
for i in range(1000):
Starblade = sb.starblade_iteration(Starblade, samples=5)
......
......@@ -70,7 +70,7 @@ if __name__ == '__main__':
plt.imsave('log_point_sex.png', np.log((data_unmod - sex).clip(min=0.0001)), vmax = vmax,vmin =vmin)
alpha = 1.4
Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=3)
Starblade = sb.build_starblade(data, alpha=alpha, cg_steps=100, newton_steps=3)
for i in range(10):
Starblade = sb.starblade_iteration(Starblade, samples=5)
......
import nifty4 as ift
import numpy as np
import matplotlib as mpl
mpl.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}'] #for \text command
from matplotlib import rc
rc('font',**{'family':'serif','serif':['Palatino']})
rc('text', usetex=True)
......@@ -43,37 +46,48 @@ if __name__ == '__main__':
dcnn_diffuse = dcnn_diffuse.clip(0.001)
dcnn_points = dcnn_points.clip(0.001)
energy1 = sb.build_starblade(data,1.0, newton_iterations=200, cg_iterations=10,
energy1 = sb.build_starblade(data,1.0, newton_steps=200, cg_steps=5,
q=q)#, manual_power_spectrum= p_spec)
energy2 = sb.build_starblade(data,1.5, newton_iterations=200, cg_iterations=10,
energy2 = sb.build_starblade(data,1.5, newton_steps=200, cg_steps=5,
q=q)#, manual_power_spectrum= p_spec)
energy3 = sb.build_starblade(data,3., newton_iterations=200, cg_iterations=10,
energy3 = sb.build_starblade(data,3., newton_steps=200, cg_steps=5,
q=q)#, manual_power_spectrum= p_spec)
# ift.extra.check_value_gradient_consistency(energy1, tol=1e-3)
for i in range(5):
energy1 = sb.starblade_iteration(energy1, samples=3)
energy2 = sb.starblade_iteration(energy2, samples=3)
energy3 = sb.starblade_iteration(energy3, samples=3)
for i in range(30):
energy1 = sb.starblade_iteration(energy1, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000)
energy2 = sb.starblade_iteration(energy2, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000)
energy3 = sb.starblade_iteration(energy3, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000)
print "error energy1:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy1.s).val) ** 2).mean())
print "error energy2:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy2.s).val) ** 2).mean())
print "error energy3:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy3.s).val) ** 2).mean())
# energy2 = sb.starblade_iteration(energy2, samples=10 , cg_steps=1000, newton_steps=30)
samples = []
n=30
n=1000
for i in range(n):
samples.append(energy2.curvature.inverse.draw_sample())
print samples[i].var()
m = 0
v = 0
s_s=0
v_s = 0
pos_tanh = ift.library.nonlinearities.PositiveTanh()
p=0
RMS = 0
RMS_s = 0
for sample in samples:
a_s = pos_tanh(energy2.position+sample)
m += a_s
v += a_s**2
s_s += ift.log(d*(1-a_s))
v_s += ift.log(d*(1-a_s))**2
sam_s = ift.log(d*(1-a_s))
RMS_s += np.sqrt(((sam_s - energy2.s)**2).mean())
s_s += sam_s
v_s += (sam_s)**2
print np.sqrt(((sample**2 ).mean()))
RMS += np.sqrt(((sample**2 ).mean()))
p += ift.power_analyze(FFT.adjoint(s))
m /= n
v /= n
......@@ -82,14 +96,18 @@ if __name__ == '__main__':
v_s /= n
v_s -= s_s**2
p /= n
RMS /=n
RMS_s /=n
true_x = np.arctanh(np.exp(brightness.val) / data * 2 - 1)
print "RMS x:", np.sqrt(((energy2.position.val - true_x)**2).mean()), RMS
lim_low = 1e-1
lim_high = 1e4
size = 15
plt.gray()
plt.figure()
plt.imshow(data,norm=LogNorm())
plt.imshow(data,norm=LogNorm(), vmin=lim_low,vmax=lim_high)
cbar = plt.colorbar()
cbar.set_label('intensity', size=size)
plt.axis('off')
......@@ -103,31 +121,31 @@ if __name__ == '__main__':
for i in range(energy1.s.val.shape[0]):
ax0.plot(ift.exp(energy1.s).val[i], 'k-',alpha=(0.15/(energy1.s.val.shape[0])*i))
ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.0$', size=size)
ax0.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.0$', size=size)
ax0.set_ylim(lim_low ,lim_high)
ax0.set_yscale("log")
for i in range(energy1.s.val.shape[0]):
ax1.plot(ift.exp(energy2.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax1.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.5$', size=size)
for i in range(energy1.s.val.shape[0]):
ax2.plot(ift.exp(energy3.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 3.0$', size=size)
ax2.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 3.0$', size=size)
for i in range(energy1.s.val.shape[0]):
ax3.plot(sextracted64[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax3.yaxis.set_label_position("right")
ax3.set_ylabel('sextractor'+ '\n'+ r'default', size=size)
ax3.set_ylabel('SExtractor'+ '\n'+ r'$64\times 64$', size=size)
for i in range(energy1.s.val.shape[0]):
ax4.plot((sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax4.yaxis.set_label_position("right")
ax5.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size)
ax4.set_ylabel(r'SExtractor'+'\n' + r'$8\times 8$', size=size)
for i in range(energy1.s.val.shape[0]):
ax5.plot((dcnn_diffuse)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax5.yaxis.set_label_position("right")
ax5.set_ylabel(r'DCAE', size=size)
ax5.set_ylabel(r'DAE', size=size)
plt.savefig('1d_diffuse.pdf')
......@@ -138,28 +156,28 @@ if __name__ == '__main__':
for i in range(energy1.s.val.shape[0]):
ax0.plot(ift.exp(energy1.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax0.yaxis.set_label_position("right")
ax0.set_ylabel(r'$\alpha = 1.0$', size=size)
ax0.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.0$', size=size)
ax0.set_ylim(lim_low ,lim_high)
ax0.set_yscale("log")
for i in range(energy1.s.val.shape[0]):
ax1.plot(ift.exp(energy2.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax1.yaxis.set_label_position("right")
ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
ax1.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.5$', size=size)
for i in range(energy1.s.val.shape[0]):
ax2.plot(ift.exp(energy3.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax2.yaxis.set_label_position("right")
ax2.set_ylabel(r'$\alpha = 3.0$', size=size)
ax2.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 3.0$', size=size)
for i in range(energy1.s.val.shape[0]):
ax3.plot((data-sextracted64)[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax3.yaxis.set_label_position("right")
ax3.set_ylabel(r'sextractor'+ '\n'+ r'default', size=size)
ax3.set_ylabel(r'SExtractor'+ '\n'+ r'$64\times 64$', size=size)
for i in range(energy1.s.val.shape[0]):
ax4.plot((data-sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax4.yaxis.set_label_position("right")
ax4.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size)
ax4.set_ylabel(r'SExtractor'+'\n' + r'$8 \times 8$', size=size)
for i in range(energy1.s.val.shape[0]):
ax5.plot((dcnn_points)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
ax5.set_ylabel(r'DCAE', size=size)
ax5.set_ylabel(r'DAE', size=size)
ax5.yaxis.set_label_position("right")
ax0.set_yscale("log")
......@@ -206,26 +224,52 @@ if __name__ == '__main__':
k_lengths=power1.domain[0].k_lengths
# plt.plot(k_lengths, p_spec(k_lengths*1024.)*1024, 'k-', label='theoretical')
plt.plot(k_lengths, power_d.val, 'k+-', label='data')
plt.plot(k_lengths, power1.val, 'k-', label=(r'$\alpha = 1.$'), alpha=0.15)
plt.plot(k_lengths, power2.val, 'k-', label=(r'$\alpha = 1.5$'),alpha=0.3)
plt.plot(k_lengths, power3.val, 'k-', label=(r'$\alpha = 3.0$'), alpha=0.6)
plt.plot(k_lengths, power_s.val, 'k:', label=('signal'))
plt.plot(k_lengths, power1.val, 'b-', label=(r'$\alpha = 1.$'), alpha=0.7)
plt.plot(k_lengths, power2.val, 'r-', label=(r'$\alpha = 1.5$'),alpha=0.7)
plt.plot(k_lengths, power3.val, 'g-', label=(r'$\alpha = 3.0$'), alpha=0.7)
plt.plot(k_lengths, power_s.val, 'k-', label=('signal'))
# plt.plot(k_lengths, power_u.val, 'k:',label='point-like')
plt.legend()
plt.yscale('log')
plt.xscale('log')
plt.title('power spectra',size=15)
plt.ylabel('power',size=15)
plt.xlabel('harmonic mode',size=15)
plt.title('power spectra',size=size)
plt.ylabel('power',size=size)
plt.xlabel('harmonic mode',size=size)
plt.savefig('1d_power.pdf')
plt.figure()
plt.scatter(s.val, np.log(dcnn_diffuse), alpha = 0.2, c='k', label='DDCAE')
plt.scatter(s.val, energy2.s.val, alpha = 0.2, c='r', label = 'starblade')
plt.title('diffuse truth vs DDCAE and starblade')
subset = (np.random.randint(0,128,5000),np.random.randint(0,128,5000))
plt.figure(figsize=(6,4))
plt.scatter(np.exp(s.val)[subset], sextracted8[subset], alpha = .2, c='b',marker='.', label='SExtractor',s=1.5)
plt.scatter(np.exp(s.val)[subset], dcnn_diffuse[subset], alpha = .7, c='r',marker='.', label='DAE',s=1.5)
plt.scatter(np.exp(s.val)[subset], np.exp(energy2.s.val)[subset], alpha = .7, c='k', marker='.', label = 'starblade',s=1.5)
plt.ylabel('method',size=size)
plt.xlabel('truth',size=size)
plt.yscale('log')
plt.xscale('log')
plt.title('diffuse truth vs recovered',size=size)
plt.legend()
plt.savefig('scatter.pdf')
plt.close('all')
plt.figure()
plt.imshow((np.sqrt(v.val)))
plt.colorbar()
plt.axis('off')
plt.title('estimated separation uncertainty', size=size)
cbar.set_label('standard deviation', size=size)
plt.savefig('uncertainty.pdf')
subset = (np.random.randint(0,128,200),np.random.randint(0,128,200))
plt.figure()
plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
s =10000*np.log(np.sqrt(((s.val[subset]-energy2.s.val[subset])**2))).flatten(),marker='o', alpha=0.5 )
plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
s=5000 * np.sqrt(v_s.val[subset]).flatten(), marker='o', alpha=0.5)
plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
s =5000*np.sqrt(((s.val[subset]-np.log(dcnn_diffuse[subset]))**2)).flatten(),marker='o', alpha=0.2 )
plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
s =5000*np.sqrt(((s.val[subset]-np.log(sextracted8[subset]))**2)).flatten(),marker='o', alpha=0.1 )
print "flux error energy1:", np.sqrt((((1 - ift.exp(s).val / ift.exp(energy1.s).val) ** 2)*d.val/d.sum()).mean())
print "flux error energy2:", np.sqrt((((1 - ift.exp(s).val / ift.exp(energy2.s).val) ** 2)*d.val/d.sum()).mean())
......@@ -239,3 +283,13 @@ if __name__ == '__main__':
print "class error back_size8:", np.sqrt(((1-sextracted8/ift.exp(s).val)**2).mean())
print "class error back_size64:", np.sqrt((((1-sextracted64/ift.exp(s).val)**2)).mean())
print "class error dcnn:", np.sqrt((((1-dcnn_diffuse/ift.exp(s).val)**2)).mean())
print "RMS energy1:", np.sqrt(((energy1.s.val - s.val) ** 2).mean())
print "RMS energy2:", np.sqrt(((energy2.s.val - s.val) ** 2).mean())
print "RMS energy3:", np.sqrt(((energy3.s.val - s.val) ** 2).mean())
print "RMS a:", np.sqrt(((energy2.a.val - np.exp(brightness.val)/data) ** 2).mean())
print "RMS sextractor8:", np.sqrt(((np.log(sextracted8) - s.val) ** 2).mean())
print "RMS sextractor64:", np.sqrt(((np.log(sextracted64) - s.val) ** 2).mean())
print "RMS dcnn:", np.sqrt(((np.log(dcnn_diffuse) - s.val) ** 2).mean())
from keras.layers import Activation, Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose, BatchNormalization, merge, Dropout
from keras.models import Model
def DDCAE_model(image_size):
input_shape = (image_size, image_size,1)
batch_size = 32
kernel_size = 3
latent_dim = 8
# Encoder/Decoder number of CNN layers and filters per layer
layer_filters = [32,16]
# Build the Autoencoder Model
# First build the Encoder Model
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
# Stack of Conv2D blocks
# Notes:
# 1) Use Batch Normalization before ReLU on deep networks
# 2) Use MaxPooling2D as alternative to strides>1
# - faster but not as good as strides>1
for filters in layer_filters:
x = Conv2D(filters=filters,
kernel_size=kernel_size,
strides=1,
activation='relu',
padding='same',
data_format='channels_last')(x)
# x = BatchNormalization()(x)
x = Dropout(0.1)(x)
for filters in layer_filters[::-1]:
x = Conv2DTranspose(filters=filters,
kernel_size=kernel_size,
strides=1,
activation='relu',
padding='same',
data_format='channels_last')(x)
# x = BatchNormalization()(x)
x = Dropout(0.1)(x)
x = Conv2DTranspose(filters=1,
kernel_size=kernel_size,
padding='same')(x)
outputs = Activation('sigmoid', name='decoder_output')(x)
outputs = merge([inputs,outputs],mode='mul')
return Model(inputs, outputs, name='autoencoder')
......@@ -9,32 +9,40 @@ from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1 import AxesGrid
from astropy.io import fits
from scipy.ndimage import zoom
import starblade as sb
from DDCAE import DDCAE_model
np.random.seed(42)
if __name__ == '__main__':
path = 'demos/data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
path = '../demos/data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
data = fits.open(path)[1].data
data = data.clip(min=0.001)
# data = data.clip(min=0.001)
data = np.ndarray.astype(data, float)
alpha = 1.28
#data = zoom(data, 0.25)
data = data.clip(min=0.001)
hdu = fits.PrimaryHDU(data)
hdul = fits.HDUList([hdu])
hdul.writeto('my_hubble_data.fits',overwrite=True)
alpha = 1.1
myEnergy = sb.build_starblade(data, alpha=alpha)
for i in range(10):
myEnergy = sb.starblade_iteration(myEnergy)
A = FFTSmoothingOperator(myEnergy.s.domain, sigma=2.)
for i in range(30):
myEnergy = sb.starblade_iteration(myEnergy, samples=5,newton_steps=30,cg_steps=10, sampling_steps=100)
print i
A = FFTSmoothingOperator(myEnergy.s.domain, sigma=.002)
plt.magma()
size = 15
vmin = data.min()+0.01
vmax = 0.01*data.max()
vmax = data.max()*0.01
plt.figure()
plt.title('diffuse emission', size=size)
plt.axis('off')
......@@ -81,16 +89,16 @@ if __name__ == '__main__':
plt.suptitle('zoomed in section', size=size)
# fig.tight_layout()
vmin = data.min() + 0.0001
vmax = 0.001 * data.max()
im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
vmin_small = data.min() + 0.0001
vmax_small = 0.001 * data.max()
im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small))
ax[0].set_title('data', size=15)
ax[0].axis('off')
ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small))
ax[1].set_title('diffuse', size=15)
ax[1].axis('off')
ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small))
ax[2].set_title('point-like', size=15)
ax[2].axis('off')
......@@ -140,7 +148,7 @@ if __name__ == '__main__':
k_lengths = power.domain[0].k_lengths
plt.plot(k_lengths, power.val, 'k-', label='diffuse')
plt.plot(k_lengths, power_d.val, 'k:', label='data')
plt.plot(k_lengths, power_d.val, 'k+-', label='data')
plt.legend()
plt.yscale('log')
plt.xscale('log')
......@@ -148,3 +156,55 @@ if __name__ == '__main__':
plt.ylabel('power',size=15)
plt.xlabel('harmonic mode',size=15)
plt.savefig('hubble_log_power.pdf')
DDCAE=DDCAE_model(None)
DDCAE.load_weights('DDCAC',by_name=True)
ddd = data.reshape([1,1000,1000,1])
separation = DDCAE.predict(ddd)
plt.figure()
plt.title('DAE diffuse emission', size=size)
plt.axis('off')
ax = plt.gca()
im = ax.imshow(separation[0,:,:,0],norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
cbar.set_label('flux', size=size)
plt.tight_layout()
plt.savefig('DAE_hubble_diffuse.pdf')
plt.figure()
plt.title('DAE point-like emission', size=size)
plt.axis('off')
ax = plt.gca()
im = ax.imshow(A(ift.Field(myEnergy.position.domain,data-separation[0,:,:,0])).val,norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
cbar.set_label('flux', size=size)
plt.tight_layout()
plt.savefig('DAE_hubble_point_like.pdf')
sextracted = fits.open('check_hubble64.fits')[0].data
plt.figure()
plt.title('SExtractor diffuse emission', size=size)
plt.axis('off')
ax = plt.gca()
im = ax.imshow(sextracted.clip(0.001),norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
cbar.set_label('flux', size=size)
plt.tight_layout()
plt.savefig('SExtractor_hubble_diffuse.pdf')
plt.figure()
plt.title('SExtractor point-like emission', size=size)
plt.axis('off')
ax = plt.gca()
im = ax.imshow(A(ift.Field(myEnergy.position.domain,(data-sextracted))).val.clip(0.001),norm=LogNorm(vmin=vmin, vmax=vmax))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
cbar.set_label('flux', size=size)
plt.tight_layout()
plt.savefig('SExtractor_hubble_point_like.pdf')
......@@ -17,7 +17,7 @@
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
from nifty4 import Energy, Field, log, exp, DiagonalOperator,\
create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler
create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler, SamplingEnabler
from nifty4.library import WienerFilterCurvature
from nifty4.library.nonlinearities import PositiveTanh, Tanh
......@@ -56,6 +56,7 @@ class StarbladeEnergy(Energy):
self.parameters = parameters
self.inverter = parameters['inverter']
self.sampling_inverter = parameters['sampling_inverter']
self.d = parameters['data']
self.FFT = parameters['FFT']
self.power_spectrum = parameters['power_spectrum']
......@@ -115,10 +116,11 @@ class StarbladeEnergy(Energy):
point = self.q * exp(-self.u) * self.u_p ** 2
# R = self.FFT.inverse * self.s_p
# N = self.correlation
N_inv = DiagonalOperator(point + 1/self.var_x )#+ 2*self.a_p))
O_x = DiagonalOperator(Field(self.position.domain,val=1./self.var_x))
N_inv = DiagonalOperator(point )#+ 2*self.a_p))
R = ScalingOperator(1., point.domain)
S_p = DiagonalOperator(self.s_p)
my_S_inv = SandwichOperator.make(self.FFT.adjoint.inverse.adjoint * S_p, self.correlation.inverse)
curv = InversionEnabler(N_inv + my_S_inv, self.inverter)
curv = InversionEnabler(SamplingEnabler(my_S_inv+N_inv, O_x, self.sampling_inverter), self.inverter)
return curv
# return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
......@@ -23,8 +23,8 @@ import nifty4 as ift
from .starblade_energy import StarbladeEnergy
from .starblade_kl import StarbladeKL
def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iterations = 3,
manual_power_spectrum = None):
def build_starblade(data, alpha=1.5, q=1e-40, cg_steps=100, newton_steps = 3,
manual_power_spectrum = None, sampling_steps = 100):
""" Setting up the StarbladeEnergy for the given data and parameters
Parameters
----------
......@@ -34,9 +34,9 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio
The slope parameter of the point source prior (default: 1.5).
q : float
The cutoff parameter of the point source prior (default: 1e-40).
cg_iterations : int
cg_steps : int
Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
newton_iterations : int
newton_steps : int
Number of consecutive Newton steps within one algorithmic step.(default: 3)
manual_power_spectrum : None, Field or callable
Option to set a manual power spectrum which is kept constant during the separation.
......@@ -63,21 +63,24 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio
initial_x = ift.Field(s_space, val=-1.)
alpha = ift.Field(s_space, val=alpha)
q = ift.Field(s_space, val=q)
ICI = ift.GradientNormController(iteration_limit=cg_iterations,
ICI = ift.GradientNormController(iteration_limit=cg_steps,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(controller=ICI)
IC_samples = ift.GradientNormController(iteration_limit=sampling_steps,
tol_abs_gradnorm=1e-5)
sampling_inverter = ift.ConjugateGradient(controller=IC_samples)
parameters = dict(data=data, power_spectrum=initial_spectrum,
alpha=alpha, q=q,
inverter=inverter, FFT=FFT,
newton_iterations=newton_iterations, update_power=update_power)
newton_iterations=newton_steps, sampling_inverter=sampling_inverter, update_power=update_power)
Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
return Starblade
def starblade_iteration(starblade, samples=3):
def starblade_iteration(starblade, samples=3, cg_steps=10, newton_steps=3, sampling_steps=100):
""" Performing one Newton minimization step
Parameters
----------
......@@ -88,16 +91,26 @@ def starblade_iteration(starblade, samples=3):
"""
controller = ift.GradientNormController(name="Newton",
tol_abs_gradnorm=1e-8,
iteration_limit=starblade.newton_iterations)
iteration_limit=newton_steps)
minimizer = ift.RelaxedNewton(controller=controller)
ICI = ift.GradientNormController(iteration_limit=cg_steps,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(controller=ICI)
IC_samples = ift.GradientNormController(iteration_limit=sampling_steps,
tol_abs_gradnorm=1e-5)
sampling_inverter = ift.ConjugateGradient(controller=IC_samples)
# minimizer = ift.VL_BFGS(controller=controller)
energy = starblade
para = starblade.parameters
para['inverter'] = inverter
para['sampling_inverter'] = sampling_inverter
energy = StarbladeEnergy(starblade.position,parameters=para)
sample_list = []
for i in range(samples):
sample = energy.curvature.inverse.draw_sample()
sample_list.append(sample)
if len(sample_list)>0:
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
energy = StarbladeKL(starblade.position, samples=sample_list, parameters=energy.parameters)
else:
energy = starblade
energy, convergence = minimizer(energy)
......@@ -140,9 +153,9 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500,
"""
MultiStarblade = []
for i in range(data.shape[-1]):
starblade = build_starblade(data[...,i],alpha=alpha, q=q,
cg_iterations=cg_iterations,
newton_iterations=newton_iterations,
starblade = build_starblade(data[...,i], alpha=alpha, q=q,