From 8c9c655611fc949e544bc9b19f40f3bb7b442598 Mon Sep 17 00:00:00 2001
From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de>
Date: Mon, 18 Jun 2018 14:32:56 +0200
Subject: [PATCH] paper plots

---
 demos/demo.py                   |   2 +-
 demos/sextractor/compare_sex.py |   2 +-
 paper/1d_separation.py          | 124 +++++++++++++++++++++++---------
 paper/DDCAE.py                  |  50 +++++++++++++
 paper/hubble_separation.py      |  92 +++++++++++++++++++-----
 starblade/starblade_energy.py   |   8 ++-
 starblade/sugar.py              |  39 ++++++----
 7 files changed, 248 insertions(+), 69 deletions(-)
 create mode 100644 paper/DDCAE.py

diff --git a/demos/demo.py b/demos/demo.py
index 132b798..8791d9a 100644
--- a/demos/demo.py
+++ b/demos/demo.py
@@ -64,7 +64,7 @@ if __name__ == '__main__':
     plt.imsave('point_sex.png', (data_unmod - sex), vmax = lin_max,vmin =lin_min)
 
     alpha = 1.4
-    Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=10)
+    Starblade = sb.build_starblade(data, alpha=alpha, cg_steps=100, newton_steps=10)
     for i in range(1000):
         Starblade = sb.starblade_iteration(Starblade, samples=5)
 
diff --git a/demos/sextractor/compare_sex.py b/demos/sextractor/compare_sex.py
index fa62758..0d38cfb 100644
--- a/demos/sextractor/compare_sex.py
+++ b/demos/sextractor/compare_sex.py
@@ -70,7 +70,7 @@ if __name__ == '__main__':
     plt.imsave('log_point_sex.png', np.log((data_unmod - sex).clip(min=0.0001)), vmax = vmax,vmin =vmin)
 
     alpha = 1.4
-    Starblade = sb.build_starblade(data, alpha=alpha,cg_iterations=100,newton_iterations=3)
+    Starblade = sb.build_starblade(data, alpha=alpha, cg_steps=100, newton_steps=3)
     for i in range(10):
         Starblade = sb.starblade_iteration(Starblade, samples=5)
 
diff --git a/paper/1d_separation.py b/paper/1d_separation.py
index 7553c72..efbe2f1 100644
--- a/paper/1d_separation.py
+++ b/paper/1d_separation.py
@@ -1,5 +1,8 @@
 import nifty4 as ift
 import numpy as np
+import matplotlib as mpl
+mpl.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}'] #for \text command
+
 from matplotlib import rc
 rc('font',**{'family':'serif','serif':['Palatino']})
 rc('text', usetex=True)
@@ -43,37 +46,48 @@ if __name__ == '__main__':
     dcnn_diffuse = dcnn_diffuse.clip(0.001)
     dcnn_points = dcnn_points.clip(0.001)
 
-    energy1 = sb.build_starblade(data,1.0, newton_iterations=200, cg_iterations=10,
+    energy1 = sb.build_starblade(data,1.0, newton_steps=200, cg_steps=5,
                                  q=q)#, manual_power_spectrum= p_spec)
-    energy2 = sb.build_starblade(data,1.5,  newton_iterations=200, cg_iterations=10,
+    energy2 = sb.build_starblade(data,1.5,  newton_steps=200, cg_steps=5,
                                  q=q)#, manual_power_spectrum= p_spec)
-    energy3 = sb.build_starblade(data,3., newton_iterations=200, cg_iterations=10,
+    energy3 = sb.build_starblade(data,3., newton_steps=200, cg_steps=5,
                                  q=q)#, manual_power_spectrum= p_spec)
     # ift.extra.check_value_gradient_consistency(energy1, tol=1e-3)
-    for i in range(5):
-        energy1 = sb.starblade_iteration(energy1, samples=3)
-        energy2 = sb.starblade_iteration(energy2, samples=3)
-        energy3 = sb.starblade_iteration(energy3, samples=3)
+    for i in range(30):
+
+        energy1 = sb.starblade_iteration(energy1, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000)
+        energy2 = sb.starblade_iteration(energy2, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000)
+        energy3 = sb.starblade_iteration(energy3, samples=1+i, cg_steps=10, newton_steps=100, sampling_steps=1000)
         print "error energy1:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy1.s).val) ** 2).mean())
         print "error energy2:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy2.s).val) ** 2).mean())
         print "error energy3:", np.sqrt(((1 - ift.exp(s).val/ift.exp(energy3.s).val) ** 2).mean())
+    # energy2 = sb.starblade_iteration(energy2, samples=10 , cg_steps=1000, newton_steps=30)
 
     samples = []
-    n=30
+    n=1000
+
     for i in range(n):
         samples.append(energy2.curvature.inverse.draw_sample())
+        print samples[i].var()
     m = 0
     v = 0
     s_s=0
     v_s = 0
     pos_tanh = ift.library.nonlinearities.PositiveTanh()
     p=0
+    RMS = 0
+    RMS_s = 0
+
     for sample in samples:
         a_s = pos_tanh(energy2.position+sample)
         m +=  a_s
         v += a_s**2
-        s_s += ift.log(d*(1-a_s))
-        v_s += ift.log(d*(1-a_s))**2
+        sam_s = ift.log(d*(1-a_s))
+        RMS_s += np.sqrt(((sam_s - energy2.s)**2).mean())
+        s_s += sam_s
+        v_s += (sam_s)**2
+        print  np.sqrt(((sample**2 ).mean()))
+        RMS += np.sqrt(((sample**2 ).mean()))
         p += ift.power_analyze(FFT.adjoint(s))
     m /= n
     v /= n
@@ -82,14 +96,18 @@ if __name__ == '__main__':
     v_s /= n
     v_s -= s_s**2
     p /= n
+    RMS /=n
+    RMS_s /=n
+    true_x = np.arctanh(np.exp(brightness.val) / data * 2 - 1)
 
+    print "RMS x:", np.sqrt(((energy2.position.val - true_x)**2).mean()), RMS
     lim_low = 1e-1
     lim_high = 1e4
     size = 15
 
     plt.gray()
     plt.figure()
-    plt.imshow(data,norm=LogNorm())
+    plt.imshow(data,norm=LogNorm(), vmin=lim_low,vmax=lim_high)
     cbar = plt.colorbar()
     cbar.set_label('intensity', size=size)
     plt.axis('off')
@@ -103,31 +121,31 @@ if __name__ == '__main__':
     for i in range(energy1.s.val.shape[0]):
         ax0.plot(ift.exp(energy1.s).val[i], 'k-',alpha=(0.15/(energy1.s.val.shape[0])*i))
     ax0.yaxis.set_label_position("right")
-    ax0.set_ylabel(r'$\alpha = 1.0$', size=size)
+    ax0.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.0$', size=size)
     ax0.set_ylim(lim_low ,lim_high)
     ax0.set_yscale("log")
 
     for i in range(energy1.s.val.shape[0]):
         ax1.plot(ift.exp(energy2.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax1.yaxis.set_label_position("right")
-    ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
+    ax1.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.5$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax2.plot(ift.exp(energy3.s).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax2.yaxis.set_label_position("right")
-    ax2.set_ylabel(r'$\alpha = 3.0$', size=size)
+    ax2.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 3.0$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax3.plot(sextracted64[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax3.yaxis.set_label_position("right")
-    ax3.set_ylabel('sextractor'+ '\n'+ r'default', size=size)
+    ax3.set_ylabel('SExtractor'+ '\n'+ r'$64\times 64$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax4.plot((sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax4.yaxis.set_label_position("right")
-    ax5.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size)
+    ax4.set_ylabel(r'SExtractor'+'\n' + r'$8\times 8$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax5.plot((dcnn_diffuse)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax5.yaxis.set_label_position("right")
 
-    ax5.set_ylabel(r'DCAE', size=size)
+    ax5.set_ylabel(r'DAE', size=size)
 
     plt.savefig('1d_diffuse.pdf')
 
@@ -138,28 +156,28 @@ if __name__ == '__main__':
     for i in range(energy1.s.val.shape[0]):
         ax0.plot(ift.exp(energy1.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax0.yaxis.set_label_position("right")
-    ax0.set_ylabel(r'$\alpha = 1.0$', size=size)
+    ax0.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.0$', size=size)
     ax0.set_ylim(lim_low ,lim_high)
     ax0.set_yscale("log")
     for i in range(energy1.s.val.shape[0]):
         ax1.plot(ift.exp(energy2.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax1.yaxis.set_label_position("right")
-    ax1.set_ylabel(r'$\alpha = 1.5$', size=size)
+    ax1.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 1.5$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax2.plot(ift.exp(energy3.u).val[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax2.yaxis.set_label_position("right")
-    ax2.set_ylabel(r'$\alpha = 3.0$', size=size)
+    ax2.set_ylabel(r'\texttt{starblade}'+ '\n'+r'$\alpha = 3.0$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax3.plot((data-sextracted64)[i], 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax3.yaxis.set_label_position("right")
-    ax3.set_ylabel(r'sextractor'+ '\n'+ r'default', size=size)
+    ax3.set_ylabel(r'SExtractor'+ '\n'+ r'$64\times 64$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax4.plot((data-sextracted8)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
     ax4.yaxis.set_label_position("right")
-    ax4.set_ylabel(r'sextractor'+'\n' + r'\textsf{BACK\_SIZE}'+r'$=8$', size=size)
+    ax4.set_ylabel(r'SExtractor'+'\n' + r'$8 \times 8$', size=size)
     for i in range(energy1.s.val.shape[0]):
         ax5.plot((dcnn_points)[i].clip(0.0001), 'k-', alpha=(0.15 / (energy1.s.val.shape[0]) * i))
-    ax5.set_ylabel(r'DCAE', size=size)
+    ax5.set_ylabel(r'DAE', size=size)
     ax5.yaxis.set_label_position("right")
 
     ax0.set_yscale("log")
@@ -206,26 +224,52 @@ if __name__ == '__main__':
     k_lengths=power1.domain[0].k_lengths
     # plt.plot(k_lengths, p_spec(k_lengths*1024.)*1024, 'k-', label='theoretical')
     plt.plot(k_lengths, power_d.val, 'k+-', label='data')
-    plt.plot(k_lengths, power1.val, 'k-', label=(r'$\alpha = 1.$'), alpha=0.15)
-    plt.plot(k_lengths, power2.val, 'k-', label=(r'$\alpha = 1.5$'),alpha=0.3)
-    plt.plot(k_lengths, power3.val, 'k-', label=(r'$\alpha = 3.0$'), alpha=0.6)
-    plt.plot(k_lengths, power_s.val, 'k:', label=('signal'))
+    plt.plot(k_lengths, power1.val, 'b-', label=(r'$\alpha = 1.$'), alpha=0.7)
+    plt.plot(k_lengths, power2.val, 'r-', label=(r'$\alpha = 1.5$'),alpha=0.7)
+    plt.plot(k_lengths, power3.val, 'g-', label=(r'$\alpha = 3.0$'), alpha=0.7)
+    plt.plot(k_lengths, power_s.val, 'k-', label=('signal'))
 
     # plt.plot(k_lengths, power_u.val, 'k:',label='point-like')
     plt.legend()
     plt.yscale('log')
     plt.xscale('log')
-    plt.title('power spectra',size=15)
-    plt.ylabel('power',size=15)
-    plt.xlabel('harmonic mode',size=15)
+    plt.title('power spectra',size=size)
+    plt.ylabel('power',size=size)
+    plt.xlabel('harmonic mode',size=size)
     plt.savefig('1d_power.pdf')
-    plt.figure()
-    plt.scatter(s.val, np.log(dcnn_diffuse), alpha = 0.2, c='k', label='DDCAE')
-    plt.scatter(s.val, energy2.s.val, alpha = 0.2, c='r', label = 'starblade')
-    plt.title('diffuse truth vs DDCAE and starblade')
+    subset = (np.random.randint(0,128,5000),np.random.randint(0,128,5000))
+    plt.figure(figsize=(6,4))
+    plt.scatter(np.exp(s.val)[subset], sextracted8[subset], alpha = .2, c='b',marker='.', label='SExtractor',s=1.5)
+    plt.scatter(np.exp(s.val)[subset], dcnn_diffuse[subset], alpha = .7, c='r',marker='.', label='DAE',s=1.5)
+    plt.scatter(np.exp(s.val)[subset], np.exp(energy2.s.val)[subset], alpha = .7, c='k', marker='.', label = 'starblade',s=1.5)
+    plt.ylabel('method',size=size)
+    plt.xlabel('truth',size=size)
+    plt.yscale('log')
+    plt.xscale('log')
+    plt.title('diffuse truth vs recovered',size=size)
     plt.legend()
     plt.savefig('scatter.pdf')
-    plt.close('all')
+
+    plt.figure()
+    plt.imshow((np.sqrt(v.val)))
+    plt.colorbar()
+    plt.axis('off')
+    plt.title('estimated separation uncertainty', size=size)
+    cbar.set_label('standard deviation', size=size)
+
+    plt.savefig('uncertainty.pdf')
+    subset = (np.random.randint(0,128,200),np.random.randint(0,128,200))
+
+    plt.figure()
+    plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
+                s =10000*np.log(np.sqrt(((s.val[subset]-energy2.s.val[subset])**2))).flatten(),marker='o', alpha=0.5 )
+    plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
+                s=5000 * np.sqrt(v_s.val[subset]).flatten(), marker='o', alpha=0.5)
+
+    plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
+                s =5000*np.sqrt(((s.val[subset]-np.log(dcnn_diffuse[subset]))**2)).flatten(),marker='o', alpha=0.2 )
+    plt.scatter(s.val[subset].flatten(), (brightness.val[subset]).flatten(),
+                s =5000*np.sqrt(((s.val[subset]-np.log(sextracted8[subset]))**2)).flatten(),marker='o', alpha=0.1 )
 
     print "flux error energy1:", np.sqrt((((1 - ift.exp(s).val / ift.exp(energy1.s).val) ** 2)*d.val/d.sum()).mean())
     print "flux error energy2:", np.sqrt((((1 - ift.exp(s).val / ift.exp(energy2.s).val) ** 2)*d.val/d.sum()).mean())
@@ -239,3 +283,13 @@ if __name__ == '__main__':
     print "class error back_size8:", np.sqrt(((1-sextracted8/ift.exp(s).val)**2).mean())
     print "class error back_size64:", np.sqrt((((1-sextracted64/ift.exp(s).val)**2)).mean())
     print "class error dcnn:", np.sqrt((((1-dcnn_diffuse/ift.exp(s).val)**2)).mean())
+
+    print "RMS energy1:", np.sqrt(((energy1.s.val -  s.val) ** 2).mean())
+    print "RMS energy2:", np.sqrt(((energy2.s.val -  s.val) ** 2).mean())
+    print "RMS energy3:", np.sqrt(((energy3.s.val -  s.val) ** 2).mean())
+    print "RMS a:", np.sqrt(((energy2.a.val -  np.exp(brightness.val)/data) ** 2).mean())
+
+    print "RMS sextractor8:", np.sqrt(((np.log(sextracted8) -  s.val) ** 2).mean())
+    print "RMS sextractor64:", np.sqrt(((np.log(sextracted64) -  s.val) ** 2).mean())
+    print "RMS dcnn:", np.sqrt(((np.log(dcnn_diffuse) -  s.val) ** 2).mean())
+
diff --git a/paper/DDCAE.py b/paper/DDCAE.py
new file mode 100644
index 0000000..02b60a6
--- /dev/null
+++ b/paper/DDCAE.py
@@ -0,0 +1,50 @@
+from keras.layers import Activation, Dense, Input
+from keras.layers import Conv2D, Flatten
+from keras.layers import Reshape, Conv2DTranspose, BatchNormalization, merge, Dropout
+from keras.models import Model
+
+
+def DDCAE_model(image_size):
+    input_shape = (image_size, image_size,1)
+    batch_size = 32
+    kernel_size = 3
+    latent_dim = 8
+    # Encoder/Decoder number of CNN layers and filters per layer
+    layer_filters = [32,16]
+
+    # Build the Autoencoder Model
+    # First build the Encoder Model
+    inputs = Input(shape=input_shape, name='encoder_input')
+    x = inputs
+    # Stack of Conv2D blocks
+    # Notes:
+    # 1) Use Batch Normalization before ReLU on deep networks
+    # 2) Use MaxPooling2D as alternative to strides>1
+    # - faster but not as good as strides>1
+    for filters in layer_filters:
+        x = Conv2D(filters=filters,
+                   kernel_size=kernel_size,
+                   strides=1,
+                   activation='relu',
+                   padding='same',
+                   data_format='channels_last')(x)
+        # x = BatchNormalization()(x)
+        x = Dropout(0.1)(x)
+
+    for filters in layer_filters[::-1]:
+        x = Conv2DTranspose(filters=filters,
+                            kernel_size=kernel_size,
+                            strides=1,
+                            activation='relu',
+                            padding='same',
+                            data_format='channels_last')(x)
+        # x = BatchNormalization()(x)
+        x = Dropout(0.1)(x)
+
+    x = Conv2DTranspose(filters=1,
+                        kernel_size=kernel_size,
+                        padding='same')(x)
+
+    outputs = Activation('sigmoid', name='decoder_output')(x)
+    outputs = merge([inputs,outputs],mode='mul')
+    return  Model(inputs, outputs, name='autoencoder')
diff --git a/paper/hubble_separation.py b/paper/hubble_separation.py
index 99d1bb4..bcfbdb4 100644
--- a/paper/hubble_separation.py
+++ b/paper/hubble_separation.py
@@ -9,32 +9,40 @@ from matplotlib.colors import LogNorm
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 from mpl_toolkits.axes_grid1 import AxesGrid
 from astropy.io import fits
-
+from scipy.ndimage import zoom
 import starblade as sb
-
-
+from DDCAE import DDCAE_model
 
 
 np.random.seed(42)
 if __name__ == '__main__':
-    path = 'demos/data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
+    path = '../demos/data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
     data = fits.open(path)[1].data
 
-    data = data.clip(min=0.001)
+    # data = data.clip(min=0.001)
 
     data = np.ndarray.astype(data, float)
-    alpha = 1.28
+    #data = zoom(data, 0.25)
+    data = data.clip(min=0.001)
+
+    hdu = fits.PrimaryHDU(data)
+    hdul = fits.HDUList([hdu])
+    hdul.writeto('my_hubble_data.fits',overwrite=True)
+
+
+    alpha = 1.1
 
 
 
     myEnergy = sb.build_starblade(data, alpha=alpha)
-    for i in range(10):
-        myEnergy = sb.starblade_iteration(myEnergy)
-    A = FFTSmoothingOperator(myEnergy.s.domain, sigma=2.)
+    for i in range(30):
+        myEnergy = sb.starblade_iteration(myEnergy, samples=5,newton_steps=30,cg_steps=10, sampling_steps=100)
+        print i
+    A = FFTSmoothingOperator(myEnergy.s.domain, sigma=.002)
     plt.magma()
     size = 15
     vmin = data.min()+0.01
-    vmax = 0.01*data.max()
+    vmax = data.max()*0.01
     plt.figure()
     plt.title('diffuse emission', size=size)
     plt.axis('off')
@@ -81,16 +89,16 @@ if __name__ == '__main__':
 
     plt.suptitle('zoomed in section', size=size)
     # fig.tight_layout()
-    vmin = data.min() + 0.0001
-    vmax = 0.001 * data.max()
-    im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
+    vmin_small = data.min() + 0.0001
+    vmax_small = 0.001 * data.max()
+    im = ax[0].imshow(data[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small))
     ax[0].set_title('data', size=15)
     ax[0].axis('off')
-    ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
+    ax[1].imshow(exp(myEnergy.s).val[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small))
     ax[1].set_title('diffuse', size=15)
     ax[1].axis('off')
 
-    ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin, vmax=vmax))
+    ax[2].imshow(exp(myEnergy.u).val[600:700, 650:750], norm=LogNorm(vmin=vmin_small, vmax=vmax_small))
     ax[2].set_title('point-like', size=15)
     ax[2].axis('off')
 
@@ -140,7 +148,7 @@ if __name__ == '__main__':
 
     k_lengths = power.domain[0].k_lengths
     plt.plot(k_lengths, power.val, 'k-', label='diffuse')
-    plt.plot(k_lengths, power_d.val, 'k:', label='data')
+    plt.plot(k_lengths, power_d.val, 'k+-', label='data')
     plt.legend()
     plt.yscale('log')
     plt.xscale('log')
@@ -148,3 +156,55 @@ if __name__ == '__main__':
     plt.ylabel('power',size=15)
     plt.xlabel('harmonic mode',size=15)
     plt.savefig('hubble_log_power.pdf')
+
+    DDCAE=DDCAE_model(None)
+    DDCAE.load_weights('DDCAC',by_name=True)
+    ddd = data.reshape([1,1000,1000,1])
+    separation = DDCAE.predict(ddd)
+    plt.figure()
+    plt.title('DAE diffuse emission', size=size)
+    plt.axis('off')
+    ax = plt.gca()
+    im = ax.imshow(separation[0,:,:,0],norm=LogNorm(vmin=vmin, vmax=vmax))
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes("right", size="5%", pad=0.05)
+    cbar = plt.colorbar(im, cax=cax)
+    cbar.set_label('flux', size=size)
+    plt.tight_layout()
+    plt.savefig('DAE_hubble_diffuse.pdf')
+    plt.figure()
+    plt.title('DAE point-like emission', size=size)
+    plt.axis('off')
+    ax = plt.gca()
+    im = ax.imshow(A(ift.Field(myEnergy.position.domain,data-separation[0,:,:,0])).val,norm=LogNorm(vmin=vmin, vmax=vmax))
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes("right", size="5%", pad=0.05)
+    cbar = plt.colorbar(im, cax=cax)
+    cbar.set_label('flux', size=size)
+    plt.tight_layout()
+    plt.savefig('DAE_hubble_point_like.pdf')
+
+    sextracted = fits.open('check_hubble64.fits')[0].data
+    plt.figure()
+    plt.title('SExtractor diffuse emission', size=size)
+    plt.axis('off')
+    ax = plt.gca()
+    im = ax.imshow(sextracted.clip(0.001),norm=LogNorm(vmin=vmin, vmax=vmax))
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes("right", size="5%", pad=0.05)
+    cbar = plt.colorbar(im, cax=cax)
+    cbar.set_label('flux', size=size)
+    plt.tight_layout()
+    plt.savefig('SExtractor_hubble_diffuse.pdf')
+
+    plt.figure()
+    plt.title('SExtractor point-like emission', size=size)
+    plt.axis('off')
+    ax = plt.gca()
+    im = ax.imshow(A(ift.Field(myEnergy.position.domain,(data-sextracted))).val.clip(0.001),norm=LogNorm(vmin=vmin, vmax=vmax))
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes("right", size="5%", pad=0.05)
+    cbar = plt.colorbar(im, cax=cax)
+    cbar.set_label('flux', size=size)
+    plt.tight_layout()
+    plt.savefig('SExtractor_hubble_point_like.pdf')
diff --git a/starblade/starblade_energy.py b/starblade/starblade_energy.py
index 8000576..be7d540 100644
--- a/starblade/starblade_energy.py
+++ b/starblade/starblade_energy.py
@@ -17,7 +17,7 @@
 # Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
 
 from nifty4 import Energy, Field, log, exp, DiagonalOperator,\
-    create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler
+    create_power_operator, SandwichOperator, ScalingOperator, InversionEnabler, SamplingEnabler
 from nifty4.library import WienerFilterCurvature
 from nifty4.library.nonlinearities import PositiveTanh, Tanh
 
@@ -56,6 +56,7 @@ class StarbladeEnergy(Energy):
 
         self.parameters = parameters
         self.inverter = parameters['inverter']
+        self.sampling_inverter = parameters['sampling_inverter']
         self.d = parameters['data']
         self.FFT = parameters['FFT']
         self.power_spectrum = parameters['power_spectrum']
@@ -115,10 +116,11 @@ class StarbladeEnergy(Energy):
         point = self.q * exp(-self.u) * self.u_p ** 2
         # R = self.FFT.inverse * self.s_p
         # N = self.correlation
-        N_inv = DiagonalOperator(point + 1/self.var_x )#+ 2*self.a_p))
+        O_x = DiagonalOperator(Field(self.position.domain,val=1./self.var_x))
+        N_inv = DiagonalOperator(point )#+ 2*self.a_p))
         R = ScalingOperator(1., point.domain)
         S_p = DiagonalOperator(self.s_p)
         my_S_inv = SandwichOperator.make(self.FFT.adjoint.inverse.adjoint * S_p, self.correlation.inverse)
-        curv = InversionEnabler(N_inv + my_S_inv, self.inverter)
+        curv = InversionEnabler(SamplingEnabler(my_S_inv+N_inv, O_x, self.sampling_inverter), self.inverter)
         return curv
         # return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)
diff --git a/starblade/sugar.py b/starblade/sugar.py
index bf76b84..19a4fe7 100644
--- a/starblade/sugar.py
+++ b/starblade/sugar.py
@@ -23,8 +23,8 @@ import nifty4 as ift
 from .starblade_energy import StarbladeEnergy
 from .starblade_kl import StarbladeKL
 
-def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iterations = 3,
-                    manual_power_spectrum = None):
+def build_starblade(data, alpha=1.5, q=1e-40, cg_steps=100, newton_steps = 3,
+                    manual_power_spectrum = None, sampling_steps = 100):
     """ Setting up the StarbladeEnergy for the given data and parameters
     Parameters
     ----------
@@ -34,9 +34,9 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio
         The slope parameter of the point source prior (default: 1.5).
     q : float
         The cutoff parameter of the point source prior (default: 1e-40).
-    cg_iterations : int
+    cg_steps : int
         Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
-    newton_iterations : int
+    newton_steps : int
         Number of consecutive Newton steps within one algorithmic step.(default: 3)
     manual_power_spectrum : None, Field or callable
         Option to set a manual power spectrum which is kept constant during the separation.
@@ -63,21 +63,24 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iteratio
     initial_x = ift.Field(s_space, val=-1.)
     alpha = ift.Field(s_space, val=alpha)
     q = ift.Field(s_space, val=q)
-    ICI = ift.GradientNormController(iteration_limit=cg_iterations,
+    ICI = ift.GradientNormController(iteration_limit=cg_steps,
                                      tol_abs_gradnorm=1e-5)
     inverter = ift.ConjugateGradient(controller=ICI)
+    IC_samples =  ift.GradientNormController(iteration_limit=sampling_steps,
+                                     tol_abs_gradnorm=1e-5)
+    sampling_inverter = ift.ConjugateGradient(controller=IC_samples)
 
     parameters = dict(data=data, power_spectrum=initial_spectrum,
                       alpha=alpha, q=q,
                       inverter=inverter, FFT=FFT,
-                      newton_iterations=newton_iterations, update_power=update_power)
+                      newton_iterations=newton_steps, sampling_inverter=sampling_inverter, update_power=update_power)
     Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
 
 
     return Starblade
 
 
-def starblade_iteration(starblade, samples=3):
+def starblade_iteration(starblade, samples=3, cg_steps=10, newton_steps=3, sampling_steps=100):
     """ Performing one Newton minimization step
     Parameters
     ----------
@@ -88,16 +91,26 @@ def starblade_iteration(starblade, samples=3):
     """
     controller = ift.GradientNormController(name="Newton",
                                             tol_abs_gradnorm=1e-8,
-                                            iteration_limit=starblade.newton_iterations)
+                                            iteration_limit=newton_steps)
     minimizer = ift.RelaxedNewton(controller=controller)
+    ICI = ift.GradientNormController(iteration_limit=cg_steps,
+                                     tol_abs_gradnorm=1e-5)
+    inverter = ift.ConjugateGradient(controller=ICI)
+    IC_samples =  ift.GradientNormController(iteration_limit=sampling_steps,
+                                     tol_abs_gradnorm=1e-5)
+    sampling_inverter = ift.ConjugateGradient(controller=IC_samples)
     # minimizer = ift.VL_BFGS(controller=controller)
-    energy = starblade
+    para = starblade.parameters
+    para['inverter'] = inverter
+    para['sampling_inverter'] = sampling_inverter
+    energy = StarbladeEnergy(starblade.position,parameters=para)
+
     sample_list = []
     for i in range(samples):
         sample = energy.curvature.inverse.draw_sample()
         sample_list.append(sample)
     if len(sample_list)>0:
-        energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
+        energy = StarbladeKL(starblade.position, samples=sample_list, parameters=energy.parameters)
     else:
         energy = starblade
     energy, convergence = minimizer(energy)
@@ -140,9 +153,9 @@ def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500,
     """
     MultiStarblade = []
     for i in range(data.shape[-1]):
-        starblade = build_starblade(data[...,i],alpha=alpha, q=q,
-                                    cg_iterations=cg_iterations,
-                                    newton_iterations=newton_iterations,
+        starblade = build_starblade(data[...,i], alpha=alpha, q=q,
+                                    cg_steps=cg_iterations,
+                                    newton_steps=newton_iterations,
                                     manual_power_spectrum = manual_power_spectrum)
         MultiStarblade.append(starblade)
     return MultiStarblade
-- 
GitLab