From 8c9c655611fc949e544bc9b19f40f3bb7b442598 Mon Sep 17 00:00:00 2001 From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de> Date: Mon, 18 Jun 2018 14:32:56 +0200 Subject: [PATCH] paper plots --- demos/demo.py | 2 +- demos/sextractor/compare_sex.py | 2 +- paper/1d_separation.py | 124 +++++++++++++++++++++++--------- paper/DDCAE.py | 50 +++++++++++++ paper/hubble_separation.py | 92 +++++++++++++++++++----- starblade/starblade_energy.py | 8 ++- starblade/sugar.py | 39 ++++++---- 7 files changed, 248 insertions(+), 69 deletions(-) create mode 100644 paper/DDCAE.py diff --git a/demos/demo.py b/demos/demo.py index 132b798..8791d9a 100644 --- a/demos/demo.py +++ b/demos/demo.py @@ -64,7 +64,7 @@ if __name__ == '__main__': plt.imsave('point_sex.png', (data_unmod - sex), vmax = lin_max,vmin =lin_min) alpha = 1.4 - Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=10) + Starblade = sb.build_starblade(data, alpha=alpha, cg_steps=100, newton_steps=10) for i in range(1000): Starblade = sb.starblade_iteration(Starblade, samples=5) diff --git a/demos/sextractor/compare_sex.py b/demos/sextractor/compare_sex.py index fa62758..0d38cfb 100644 --- a/demos/sextractor/compare_sex.py +++ b/demos/sextractor/compare_sex.py @@ -70,7 +70,7 @@ if __name__ == '__main__': plt.imsave('log_point_sex.png', np.log((data_unmod - sex).clip(min=0.0001)), vmax = vmax,vmin =vmin) alpha = 1.4 - Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=3) + Starblade = sb.build_starblade(data, alpha=alpha, cg_steps=100, newton_steps=3) for i in range(10): Starblade = sb.starblade_iteration(Starblade, samples=5) diff --git a/paper/1d_separation.py b/paper/1d_separation.py index 7553c72..efbe2f1 100644 --- a/paper/1d_separation.py +++ b/paper/1d_separation.py @@ -1,5 +1,8 @@ import nifty4 as ift import numpy as np +import matplotlib as mpl +mpl.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}'] #for \text command + from matplotlib import rc rc('font',**{'family':'serif','serif':['Palatino']}) rc('text', usetex=True) @@ -43,37 +46,48 @@ if __name__ == '__main__': dcnn_diffuse = dcnn_diffuse.clip(0.001) dcnn_points = dcnn_points.clip(0.001) - energy1 = sb.build_starblade(data,1.0, newton_iterations=200, cg_iterations=10, + energy1 = sb.build_starblade(data,1.0, newton_steps=200, cg_steps=5, q=q)#, manual_power_spectrum= p_spec) - energy2 = sb.build_starblade(data,1.5, newton_iterations=200, cg_iterations=10, + energy2 = sb.build_starblade(data,1.5, newton_steps=200, cg_steps=5, q=q)#, manual_power_spectrum= p_spec) - energy3 = sb.build_starblade(data,3., newton_iterations=200, cg_iterations=10, + energy3 = sb.build_starblade(data,3., newton_steps=200, cg_steps=5, q=q)#, manual_power_spectrum= p_spec) # ift.extra.check_value_gradient_consistency(energy1, tol=1e-3) - for i in range(5): - energy1 = sb.starblade_iteration(energy1, samples=3) - energy2 = sb.starblade_iteration(energy2, samples=3) - energy3 = sb.starblade_iteration(energy3, samples=3) + for i in range(30): + + energy1 = sb.starblade_iteration(energy1, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000) + energy2 = sb.starblade_iteration(energy2, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000) + energy3 = sb.starblade_iteration(energy3, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000) print "error energy1:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy1.s).val) ** 2).mean()) print "error energy2:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy2.s).val) ** 2).mean()) print "error energy3:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy3.s).val) ** 2).mean()) + # energy2 = sb.starblade_iteration(energy2, samples=10 , cg_steps=1000, newton_steps=30) samples = [] - n=30 + n=1000 + for i in range(n): samples.append(energy2.curvature.inverse.draw_sample()) + print samples[i].var() m = 0 v = 0 s_s=0 v_s = 0 pos_tanh = ift.library.nonlinearities.PositiveTanh() p=0 + RMS = 0 + RMS_s = 0 + for sample in samples: a_s = pos_tanh(energy2.position+sample) m += a_s v += a_s**2 - s_s += ift.log(d*(1-a_s)) - v_s += ift.log(d*(1-a_s))**2 + sam_s = ift.log(d*(1-a_s)) + RMS_s += np.sqrt(((sam_s - energy2.s)**2).mean()) + s_s += sam_s + v_s += (sam_s)**2 + print np.sqrt(((sample**2 ).mean())) + RMS += np.sqrt(((sample**2 ).mean())) p += ift.power_analyze(FFT.adjoint(s)) m /= n v /= n @@ -82,14 +96,18 @@ if __name__ == '__main__': v_s /= n v_s -= s_s**2 p /= n + RMS /=n + RMS_s /=n + true_x = np.arctanh(np.exp(brightness.val) / data * 2 - 1) + print "RMS x:", np.sqrt(((energy2.position.val - true_x)**2).mean()), RMS lim_low = 1e-1 lim_high = 1e4 size = 15 plt.gray() plt.figure() - plt.imshow(data,norm=LogNorm()) + plt.imshow(data,norm=LogNorm(), vmin=lim_low,vmax=lim_high) cbar = plt.colorbar() cbar.set_label('intensity', size=size) plt.axis('off') @@ -103,31 +121,31 @@ if __name__ == '__main__': for i in range(energy1.s.val.shape[0]): ax0.plot(ift.exp(energy1.s).val[i], 'k-',alpha=(0.15/(energy1.s.val.shape[0])*i)) ax0.yaxis.set_label_position("right") - ax0.set_ylabel(r'$\alpha = 1.0$', size=size) + ax0.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.0$', size=size) ax0.set_ylim(lim_low ,lim_high) ax0.set_yscale("log") for i in range(energy1.s.val.shape[0]): ax1.plot(ift.exp(energy2.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax1.yaxis.set_label_position("right") - ax1.set_ylabel(r'$\alpha = 1.5$', size=size) + ax1.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.5$', size=size) for i in range(energy1.s.val.shape[0]): ax2.plot(ift.exp(energy3.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax2.yaxis.set_label_position("right") - ax2.set_ylabel(r'$\alpha = 3.0$', size=size) + ax2.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 3.0$', size=size) for i in range(energy1.s.val.shape[0]): ax3.plot(sextracted64[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax3.yaxis.set_label_position("right") - ax3.set_ylabel('sextractor'+ '\n'+ r'default', size=size) + ax3.set_ylabel('SExtractor'+ '\n'+ r'$64\times 64$', size=size) for i in range(energy1.s.val.shape[0]): ax4.plot((sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax4.yaxis.set_label_position("right") - ax5.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size) + ax4.set_ylabel(r'SExtractor'+'\n' + r'$8\times 8$', size=size) for i in range(energy1.s.val.shape[0]): ax5.plot((dcnn_diffuse)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax5.yaxis.set_label_position("right") - ax5.set_ylabel(r'DCAE', size=size) + ax5.set_ylabel(r'DAE', size=size) plt.savefig('1d_diffuse.pdf') @@ -138,28 +156,28 @@ if __name__ == '__main__': for i in range(energy1.s.val.shape[0]): ax0.plot(ift.exp(energy1.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax0.yaxis.set_label_position("right") - ax0.set_ylabel(r'$\alpha = 1.0$', size=size) + ax0.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.0$', size=size) ax0.set_ylim(lim_low ,lim_high) ax0.set_yscale("log") for i in range(energy1.s.val.shape[0]): ax1.plot(ift.exp(energy2.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax1.yaxis.set_label_position("right") - ax1.set_ylabel(r'$\alpha = 1.5$', size=size) + ax1.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.5$', size=size) for i in range(energy1.s.val.shape[0]): ax2.plot(ift.exp(energy3.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax2.yaxis.set_label_position("right") - ax2.set_ylabel(r'$\alpha = 3.0$', size=size) + ax2.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 3.0$', size=size) for i in range(energy1.s.val.shape[0]): ax3.plot((data-sextracted64)[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax3.yaxis.set_label_position("right") - ax3.set_ylabel(r'sextractor'+ '\n'+ r'default', size=size) + ax3.set_ylabel(r'SExtractor'+ '\n'+ r'$64\times 64$', size=size) for i in range(energy1.s.val.shape[0]): ax4.plot((data-sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) ax4.yaxis.set_label_position("right") - ax4.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size) + ax4.set_ylabel(r'SExtractor'+'\n' + r'$8 \times 8$', size=size) for i in range(energy1.s.val.shape[0]): ax5.plot((dcnn_points)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i)) - ax5.set_ylabel(r'DCAE', size=size) + ax5.set_ylabel(r'DAE', size=size) ax5.yaxis.set_label_position("right") ax0.set_yscale("log") @@ -206,26 +224,52 @@ if __name__ == '__main__': k_lengths=power1.domain[0].k_lengths # plt.plot(k_lengths, p_spec(k_lengths*1024.)*1024, 'k-', label='theoretical') plt.plot(k_lengths, power_d.val, 'k+-', label='data') - plt.plot(k_lengths, power1.val, 'k-', label=(r'$\alpha = 1.$'), alpha=0.15) - plt.plot(k_lengths, power2.val, 'k-', label=(r'$\alpha = 1.5$'),alpha=0.3) - plt.plot(k_lengths, power3.val, 'k-', label=(r'$\alpha = 3.0$'), alpha=0.6) - plt.plot(k_lengths, power_s.val, 'k:', label=('signal')) + plt.plot(k_lengths, power1.val, 'b-', label=(r'$\alpha = 1.$'), alpha=0.7) + plt.plot(k_lengths, power2.val, 'r-', label=(r'$\alpha = 1.5$'),alpha=0.7) + plt.plot(k_lengths, power3.val, 'g-', label=(r'$\alpha = 3.0$'), alpha=0.7) + plt.plot(k_lengths, power_s.val, 'k-', label=('signal')) # plt.plot(k_lengths, power_u.val, 'k:',label='point-like') plt.legend() plt.yscale('log') plt.xscale('log') - plt.title('power spectra',size=15) - plt.ylabel('power',size=15) - plt.xlabel('harmonic mode',size=15) + plt.title('power spectra',size=size) + plt.ylabel('power',size=size) + plt.xlabel('harmonic mode',size=size) plt.savefig('1d_power.pdf') - plt.figure() - plt.scatter(s.val, np.log(dcnn_diffuse), alpha = 0.2, c='k', label='DDCAE') - plt.scatter(s.val, energy2.s.val, alpha = 0.2, c='r', label = 'starblade') - plt.title('diffuse truth vs DDCAE and starblade') + subset = (np.random.randint(0,128,5000),np.random.randint(0,128,5000)) + plt.figure(figsize=(6,4)) + plt.scatter(np.exp(s.val)[subset], sextracted8[subset], alpha = .2, c='b',marker='.', label='SExtractor',s=1.5) + plt.scatter(np.exp(s.val)[subset], dcnn_diffuse[subset], alpha = .7, c='r',marker='.', label='DAE',s=1.5) + plt.scatter(np.exp(s.val)[subset], np.exp(energy2.s.val)[subset], alpha = .7, c='k', marker='.', label = 'starblade',s=1.5) + plt.ylabel('method',size=size) + plt.xlabel('truth',size=size) + plt.yscale('log') + plt.xscale('log') + plt.title('diffuse truth vs recovered',size=size) plt.legend() plt.savefig('scatter.pdf') - plt.close('all') + + plt.figure() + plt.imshow((np.sqrt(v.val))) + plt.colorbar() + plt.axis('off') + plt.title('estimated separation uncertainty', size=size) + cbar.set_label('standard deviation', size=size) + + plt.savefig('uncertainty.pdf') + subset = (np.random.randint(0,128,200),np.random.randint(0,128,200)) + + plt.figure() + plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(), + s =10000*np.log(np.sqrt(((s.val[subset]-energy2.s.val[subset])**2))).flatten(),marker='o', alpha=0.5 ) + plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(), + s=5000 * np.sqrt(v_s.val[subset]).flatten(), marker='o', alpha=0.5) + + plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(), + s =5000*np.sqrt(((s.val[subset]-np.log(dcnn_diffuse[subset]))**2)).flatten(),marker='o', alpha=0.2 ) + plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(), + s =5000*np.sqrt(((s.val[subset]-np.log(sextracted8[subset]))**2)).flatten(),marker='o', alpha=0.1 ) print "flux error energy1:", np.sqrt((((1 - ift.exp(s).val / ift.exp(energy1.s).val) ** 2)*d.val/d.sum()).mean()) print "flux error energy2:", np.sqrt((((1 - ift.exp(s).val / ift.exp(energy2.s).val) ** 2)*d.val/d.sum()).mean()) @@ -239,3 +283,13 @@ if __name__ == '__main__': print "class error back_size8:", np.sqrt(((1-sextracted8/ift.exp(s).val)**2).mean()) print "class error back_size64:", np.sqrt((((1-sextracted64/ift.exp(s).val)**2)).mean()) print "class error dcnn:", np.sqrt((((1-dcnn_diffuse/ift.exp(s).val)**2)).mean()) + + print "RMS energy1:", np.sqrt(((energy1.s.val - s.val) ** 2).mean()) + print "RMS energy2:", np.sqrt(((energy2.s.val - s.val) ** 2).mean()) + print "RMS energy3:", np.sqrt(((energy3.s.val - s.val) ** 2).mean()) + print "RMS a:", np.sqrt(((energy2.a.val - np.exp(brightness.val)/data) ** 2).mean()) + + print "RMS sextractor8:", np.sqrt(((np.log(sextracted8) - s.val) ** 2).mean()) + print "RMS sextractor64:", np.sqrt(((np.log(sextracted64) - s.val) ** 2).mean()) + print "RMS dcnn:", np.sqrt(((np.log(dcnn_diffuse) - s.val) ** 2).mean()) + diff --git a/paper/DDCAE.py b/paper/DDCAE.py new file mode 100644 index 0000000..02b60a6 --- /dev/null +++ b/paper/DDCAE.py @@ -0,0 +1,50 @@ +from keras.layers import Activation, Dense, Input +from keras.layers import Conv2D, Flatten +from keras.layers import Reshape, Conv2DTranspose, BatchNormalization, merge, Dropout +from keras.models import Model + + +def DDCAE_model(image_size): + input_shape = (image_size, image_size,1) + batch_size = 32 + kernel_size = 3 + latent_dim = 8 + # Encoder/Decoder number of CNN layers and filters per layer + layer_filters = [32,16] + + # Build the Autoencoder Model + # First build the Encoder Model + inputs = Input(shape=input_shape, name='encoder_input') + x = inputs + # Stack of Conv2D blocks + # Notes: + # 1) Use Batch Normalization before ReLU on deep networks + # 2) Use MaxPooling2D as alternative to strides>1 + # - faster but not as good as strides>1 + for filters in layer_filters: + x = Conv2D(filters=filters, + kernel_size=kernel_size, + strides=1, + activation='relu', + padding='same', + data_format='channels_last')(x) + # x = BatchNormalization()(x) + x = Dropout(0.1)(x) + + for filters in layer_filters[::-1]: + x = Conv2DTranspose(filters=filters, + kernel_size=kernel_size, + strides=1, + activation='relu', + padding='same', + data_format='channels_last')(x) + # x = BatchNormalization()(x) + x = Dropout(0.1)(x) + + x = Conv2DTranspose(filters=1, + kernel_size=kernel_size, + padding='same')(x) + + outputs = Activation('sigmoid', name='decoder_output')(x) + outputs = merge([inputs,outputs],mode='mul') + return Model(inputs, outputs, name='autoencoder') diff --git a/paper/hubble_separation.py b/paper/hubble_separation.py index 99d1bb4..bcfbdb4 100644 --- a/paper/hubble_separation.py +++ b/paper/hubble_separation.py @@ -9,32 +9,40 @@ from matplotlib.colors import LogNorm from mpl_toolkits.axes_grid1 import make_axes_locatable from mpl_toolkits.axes_grid1 import AxesGrid from astropy.io import fits - +from scipy.ndimage import zoom import starblade as sb - - +from DDCAE import DDCAE_model np.random.seed(42) if __name__ == '__main__': - path = 'demos/data/hst_05195_01_wfpc2_f702w_pc_sci.fits' + path = '../demos/data/hst_05195_01_wfpc2_f702w_pc_sci.fits' data = fits.open(path)[1].data - data = data.clip(min=0.001) + # data = data.clip(min=0.001) data = np.ndarray.astype(data, float) - alpha = 1.28 + #data = zoom(data, 0.25) + data = data.clip(min=0.001) + + hdu = fits.PrimaryHDU(data) + hdul = fits.HDUList([hdu]) + hdul.writeto('my_hubble_data.fits',overwrite=True) + + + alpha = 1.1 myEnergy = sb.build_starblade(data, alpha=alpha) - for i in range(10): - myEnergy = sb.starblade_iteration(myEnergy) - A = FFTSmoothingOperator(myEnergy.s.domain, sigma=2.) + for i in range(30): + myEnergy = sb.starblade_iteration(myEnergy, samples=5,newton_steps=30,cg_steps=10, sampling_steps=100) + print i + A = FFTSmoothingOperator(myEnergy.s.domain, sigma=.002) plt.magma() size = 15 vmin = data.min()+0.01 - vmax = 0.01*data.max() + vmax = data.max()*0.01 plt.figure() plt.title('diffuse emission', size=size) plt.axis('off') @@ -81,16 +89,16 @@ if __name__ == '__main__': plt.suptitle('zoomed in section', size=size) # fig.tight_layout() - vmin = data.min() + 0.0001 - vmax = 0.001 * data.max() - im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax)) + vmin_small = data.min() + 0.0001 + vmax_small = 0.001 * data.max() + im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small)) ax[0].set_title('data', size=15) ax[0].axis('off') - ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax)) + ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small)) ax[1].set_title('diffuse', size=15) ax[1].axis('off') - ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax)) + ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small)) ax[2].set_title('point-like', size=15) ax[2].axis('off') @@ -140,7 +148,7 @@ if __name__ == '__main__': k_lengths = power.domain[0].k_lengths plt.plot(k_lengths, power.val, 'k-', label='diffuse') - plt.plot(k_lengths, power_d.val, 'k:', label='data') + plt.plot(k_lengths, power_d.val, 'k+-', label='data') plt.legend() plt.yscale('log') plt.xscale('log') @@ -148,3 +156,55 @@ if __name__ == '__main__': plt.ylabel('power',size=15) plt.xlabel('harmonic mode',size=15) plt.savefig('hubble_log_power.pdf') + + DDCAE=DDCAE_model(None) + DDCAE.load_weights('DDCAC',by_name=True) + ddd = data.reshape([1,1000,1000,1]) + separation = DDCAE.predict(ddd) + plt.figure() + plt.title('DAE diffuse emission', size=size) + plt.axis('off') + ax = plt.gca() + im = ax.imshow(separation[0,:,:,0],norm=LogNorm(vmin=vmin, vmax=vmax)) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = plt.colorbar(im, cax=cax) + cbar.set_label('flux', size=size) + plt.tight_layout() + plt.savefig('DAE_hubble_diffuse.pdf') + plt.figure() + plt.title('DAE point-like emission', size=size) + plt.axis('off') + ax = plt.gca() + im = ax.imshow(A(ift.Field(myEnergy.position.domain,data-separation[0,:,:,0])).val,norm=LogNorm(vmin=vmin, vmax=vmax)) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = plt.colorbar(im, cax=cax) + cbar.set_label('flux', size=size) + plt.tight_layout() + plt.savefig('DAE_hubble_point_like.pdf') + + sextracted = fits.open('check_hubble64.fits')[0].data + plt.figure() + plt.title('SExtractor diffuse emission', size=size) + plt.axis('off') + ax = plt.gca() + im = ax.imshow(sextracted.clip(0.001),norm=LogNorm(vmin=vmin, vmax=vmax)) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = plt.colorbar(im, cax=cax) + cbar.set_label('flux', size=size) + plt.tight_layout() + plt.savefig('SExtractor_hubble_diffuse.pdf') + + plt.figure() + plt.title('SExtractor point-like emission', size=size) + plt.axis('off') + ax = plt.gca() + im = ax.imshow(A(ift.Field(myEnergy.position.domain,(data-sextracted))).val.clip(0.001),norm=LogNorm(vmin=vmin, vmax=vmax)) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = plt.colorbar(im, cax=cax) + cbar.set_label('flux', size=size) + plt.tight_layout() + plt.savefig('SExtractor_hubble_point_like.pdf') diff --git a/starblade/starblade_energy.py b/starblade/starblade_energy.py index 8000576..be7d540 100644 --- a/starblade/starblade_energy.py +++ b/starblade/starblade_energy.py @@ -17,7 +17,7 @@ # Starblade is being developed at the Max-Planck-Institut fuer Astrophysik from nifty4 import Energy, Field, log, exp, DiagonalOperator,\ - create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler + create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler, SamplingEnabler from nifty4.library import WienerFilterCurvature from nifty4.library.nonlinearities import PositiveTanh, Tanh @@ -56,6 +56,7 @@ class StarbladeEnergy(Energy): self.parameters = parameters self.inverter = parameters['inverter'] + self.sampling_inverter = parameters['sampling_inverter'] self.d = parameters['data'] self.FFT = parameters['FFT'] self.power_spectrum = parameters['power_spectrum'] @@ -115,10 +116,11 @@ class StarbladeEnergy(Energy): point = self.q * exp(-self.u) * self.u_p ** 2 # R = self.FFT.inverse * self.s_p # N = self.correlation - N_inv = DiagonalOperator(point + 1/self.var_x )#+ 2*self.a_p)) + O_x = DiagonalOperator(Field(self.position.domain,val=1./self.var_x)) + N_inv = DiagonalOperator(point )#+ 2*self.a_p)) R = ScalingOperator(1., point.domain) S_p = DiagonalOperator(self.s_p) my_S_inv = SandwichOperator.make(self.FFT.adjoint.inverse.adjoint * S_p, self.correlation.inverse) - curv = InversionEnabler(N_inv + my_S_inv, self.inverter) + curv = InversionEnabler(SamplingEnabler(my_S_inv+N_inv, O_x, self.sampling_inverter), self.inverter) return curv # return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter) diff --git a/starblade/sugar.py b/starblade/sugar.py index bf76b84..19a4fe7 100644 --- a/starblade/sugar.py +++ b/starblade/sugar.py @@ -23,8 +23,8 @@ import nifty4 as ift from .starblade_energy import StarbladeEnergy from .starblade_kl import StarbladeKL -def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iterations = 3, - manual_power_spectrum = None): +def build_starblade(data, alpha=1.5, q=1e-40, cg_steps=100, newton_steps = 3, + manual_power_spectrum = None, sampling_steps = 100): """ Setting up the StarbladeEnergy for the given data and parameters Parameters ---------- @@ -34,9 +34,9 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio The slope parameter of the point source prior (default: 1.5). q : float The cutoff parameter of the point source prior (default: 1e-40). - cg_iterations : int + cg_steps : int Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500). - newton_iterations : int + newton_steps : int Number of consecutive Newton steps within one algorithmic step.(default: 3) manual_power_spectrum : None, Field or callable Option to set a manual power spectrum which is kept constant during the separation. @@ -63,21 +63,24 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio initial_x = ift.Field(s_space, val=-1.) alpha = ift.Field(s_space, val=alpha) q = ift.Field(s_space, val=q) - ICI = ift.GradientNormController(iteration_limit=cg_iterations, + ICI = ift.GradientNormController(iteration_limit=cg_steps, tol_abs_gradnorm=1e-5) inverter = ift.ConjugateGradient(controller=ICI) + IC_samples = ift.GradientNormController(iteration_limit=sampling_steps, + tol_abs_gradnorm=1e-5) + sampling_inverter = ift.ConjugateGradient(controller=IC_samples) parameters = dict(data=data, power_spectrum=initial_spectrum, alpha=alpha, q=q, inverter=inverter, FFT=FFT, - newton_iterations=newton_iterations, update_power=update_power) + newton_iterations=newton_steps, sampling_inverter=sampling_inverter, update_power=update_power) Starblade = StarbladeEnergy(position=initial_x, parameters=parameters) return Starblade -def starblade_iteration(starblade, samples=3): +def starblade_iteration(starblade, samples=3, cg_steps=10, newton_steps=3, sampling_steps=100): """ Performing one Newton minimization step Parameters ---------- @@ -88,16 +91,26 @@ def starblade_iteration(starblade, samples=3): """ controller = ift.GradientNormController(name="Newton", tol_abs_gradnorm=1e-8, - iteration_limit=starblade.newton_iterations) + iteration_limit=newton_steps) minimizer = ift.RelaxedNewton(controller=controller) + ICI = ift.GradientNormController(iteration_limit=cg_steps, + tol_abs_gradnorm=1e-5) + inverter = ift.ConjugateGradient(controller=ICI) + IC_samples = ift.GradientNormController(iteration_limit=sampling_steps, + tol_abs_gradnorm=1e-5) + sampling_inverter = ift.ConjugateGradient(controller=IC_samples) # minimizer = ift.VL_BFGS(controller=controller) - energy = starblade + para = starblade.parameters + para['inverter'] = inverter + para['sampling_inverter'] = sampling_inverter + energy = StarbladeEnergy(starblade.position,parameters=para) + sample_list = [] for i in range(samples): sample = energy.curvature.inverse.draw_sample() sample_list.append(sample) if len(sample_list)>0: - energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters) + energy = StarbladeKL(starblade.position, samples=sample_list, parameters=energy.parameters) else: energy = starblade energy, convergence = minimizer(energy) @@ -140,9 +153,9 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, """ MultiStarblade = [] for i in range(data.shape[-1]): - starblade = build_starblade(data[...,i],alpha=alpha, q=q, - cg_iterations=cg_iterations, - newton_iterations=newton_iterations, + starblade = build_starblade(data[...,i], alpha=alpha, q=q, + cg_steps=cg_iterations, + newton_steps=newton_iterations, manual_power_spectrum = manual_power_spectrum) MultiStarblade.append(starblade) return MultiStarblade -- GitLab