From e091563a73628174bdd2434ff72668dbb4e56b67 Mon Sep 17 00:00:00 2001
From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de>
Date: Wed, 25 Apr 2018 16:22:00 +0200
Subject: [PATCH] enabling KL

---
 demos/KL_demo.py          | 138 ++++++++++++++++++++++++++++++++++++++
 demos/clipping.py         |  91 +++++++++++++++++++++++++
 demos/demo.py             |  18 +++--
 starblade/__init__.py     |   2 +
 starblade/starblade_kl.py |  60 +++++++++++++++++
 starblade/sugar.py        |  38 ++++++++---
 6 files changed, 333 insertions(+), 14 deletions(-)
 create mode 100644 demos/KL_demo.py
 create mode 100644 demos/clipping.py
 create mode 100644 starblade/starblade_kl.py

diff --git a/demos/KL_demo.py b/demos/KL_demo.py
new file mode 100644
index 0000000..657ef6c
--- /dev/null
+++ b/demos/KL_demo.py
@@ -0,0 +1,138 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2017-2018 Max-Planck-Society
+# Author: Jakob Knollmueller
+#
+# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
+
+import numpy as np
+from astropy.io import fits
+from matplotlib import pyplot as plt
+from multiprocessing import Pool
+
+import nifty4 as ift
+from nifty4.library.nonlinearities import PositiveTanh
+
+
+import starblade as sb
+from starblade.starblade_energy import StarbladeEnergy
+from starblade.starblade_kl import StarbladeKL
+
+def power_update(KL_energy):
+    power = 0.
+    for energy in KL_energy.energy_list:
+        power += ift.power_analyze(FFT.inverse_times(energy.s),
+                                             binbounds=p_space.binbounds)
+    power /= len(KL_energy.energy_list)
+    return power
+
+if __name__ == '__main__':
+    #specifying location of the input file:
+    path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
+    path = 'data/frame-u-006174-2-0094.fits'
+    # path = 'data/frame-g-002821-6-0141.fits'
+    path = 'data/frame-g-007812-6-0100.fits'
+    path = 'data/frame-i-004874-3-0692.fits'
+
+    # data = fits.open(path)[1].data
+    data = fits.open(path)[0].data#[1000:,1250:]
+    data -= data.min() - 0.001
+    # data = np.exp(2*(1.-plt.imread('data/sdss.png').T[0]))
+    # data = (plt.imread('data/m51_3.jpg').T[0])
+    # data = (plt.imread('data/12_FBP.png').T[0])
+
+
+    #
+    # data = data.clip(min=0.001)
+
+
+    data = np.ndarray.astype(data, float)
+    vmin = np.log(data.min()+0.01)
+    vmax = np.log(data.max())
+    plt.imsave('data.png', np.log(data))
+    postanh=PositiveTanh()
+    alpha = 1.5
+    s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
+    h_space = s_space.get_default_codomain()
+    data = ift.Field(s_space,val=data)
+    FFT = ift.FFTOperator(h_space, target=s_space)
+    binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
+    p_space = ift.PowerSpace(h_space, binbounds=binbounds)
+    initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)),
+                                             binbounds=p_space.binbounds)
+    initial_spectrum /= (p_space.k_lengths+1.)**4
+    update_power = True
+
+    initial_x = ift.Field(s_space, val=-1.)
+    alpha = ift.Field(s_space, val=alpha)
+    q = ift.Field(s_space, val=1e-30)
+    ICI = ift.GradientNormController(iteration_limit=100,
+                                     tol_abs_gradnorm=1e-3)
+    inverter = ift.ConjugateGradient(controller=ICI)
+
+    parameters = dict(data=data, power_spectrum=initial_spectrum,
+                      alpha=alpha, q=q,
+                      inverter=inverter, FFT=FFT,
+                      newton_iterations=5, update_power=update_power)
+    current_x = initial_x
+    for i in range(10):
+        Starblade = StarbladeEnergy(position=current_x, parameters=parameters)
+        samples = []
+        for i in range(3):
+            sample = Starblade.curvature.inverse.draw_sample()
+            samples.append(sample)
+        problem = StarbladeKL(current_x, samples,parameters)
+
+        controller = ift.GradientNormController(name="Newton",
+                                                tol_abs_gradnorm=1e-5,
+                                                iteration_limit=5)
+        minimizer = ift.RelaxedNewton(controller=controller)
+        problem, convergence = minimizer(problem)
+        current_x = problem.position
+        parameters['power_spectrum'] = power_update(problem)
+        Starblade = StarbladeEnergy(position=current_x, parameters=parameters)
+
+    # Starblade = sb.build_starblade(data, alpha=alpha)
+    # for i in range(10):
+    #     Starblade = sb.starblade_iteration(Starblade)
+    #
+    #     #plotting on logarithmic scale
+        plt.imsave('diffuse_component.png', (Starblade.s).val,vmin=vmin, vmax=vmax)
+        plt.imsave('pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax)
+    Starblade = StarbladeEnergy(position=current_x, parameters=parameters)
+    var = 0.
+    mean = 0
+    samps = 30
+    for i in range(samps):
+        sam = postanh(Starblade.position+Starblade.curvature.inverse.draw_sample())
+        mean += sam
+        var += sam**2
+
+    var /= samps
+    mean /= samps
+    var -= mean**2
+    mask = ift.sqrt(var) < 0.01 +0.
+    plt.imsave('masked_points.png', mask.val * Starblade.u.val, vmin=vmin, vmax=vmax)
+    plt.imsave('masked_diffuse.png', mask.val * Starblade.s.val)
+
+    plt.imsave('std.png', np.log(np.sqrt(var.val)*data.val), vmin=-3.3)
+    #     plt.figure()
+    #     k_lenghts = Starblade.power_spectrum.domain[0].k_lengths
+    #     plt.plot(k_lenghts, Starblade.power_spectrum.val)
+    #     plt.title('power spectrum')
+    #     plt.yscale('log')
+    #     plt.xscale('log')
+    #     plt.ylabel('power')
+    #     plt.xscale('harmonic mode')
+    #     plt.savefig('power_spectrum.png')
diff --git a/demos/clipping.py b/demos/clipping.py
new file mode 100644
index 0000000..2e0c72a
--- /dev/null
+++ b/demos/clipping.py
@@ -0,0 +1,91 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2017-2018 Max-Planck-Society
+# Author: Jakob Knollmueller
+#
+# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
+
+import numpy as np
+from astropy.io import fits
+from matplotlib import pyplot as plt
+from scipy.ndimage.filters import median_filter
+import starblade as sb
+
+if __name__ == '__main__':
+    #specifying location of the input file:
+    # path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
+    # data = fits.open(path)[1].data
+    path = 'data/frame-i-004874-3-0692.fits'
+    path ='data/check.fits'
+    # data = fits.open(path)[1].data
+    data = fits.open(path)[0].data[1000:,1250:]
+    data -= data.min() - 0.001
+    data = data.clip(min=0.001)
+
+    data_true = data.copy()
+
+    data = np.ndarray.astype(data, float)
+    vmin = np.log(data.min()+0.01)
+    vmax = np.log(data.max())
+
+    local_size = 4
+    for i in range(5):
+        for i in range(data.shape[0]/local_size):
+            for j in range(data.shape[1]/local_size):
+                local_data = data[i*local_size:(1+i)*local_size,j*local_size:(1+j)*local_size]
+                local_data_median = np.median(local_data)
+                local_data_var = local_data.var()
+                local_data = local_data.clip(min=local_data_median - 3*np.sqrt(local_data_var),
+                                             max=local_data_median + 3*np.sqrt(local_data_var))
+                data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size] = local_data
+
+
+    background = np.empty_like(data)
+    crowded = np.zeros_like(data)
+    for i in range(data.shape[0] / local_size):
+        for j in range(data.shape[1] / local_size):
+            local_true_data = data_true[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size]
+            local_data = data[i * local_size:(1 + i) * local_size, j * local_size:(1 + j) * local_size]
+            local_true_var = local_true_data.var()
+            local_var = local_data.var()
+            if 0.8 * np.sqrt(local_true_var) >  np.sqrt(local_var):
+                background[i * local_size:(1 + i) * local_size,
+                j * local_size:(1 + j) * local_size] = 2.5*np.median(local_data)-1.5*local_data.mean()
+                crowded[i * local_size:(1 + i) * local_size,
+                j * local_size:(1 + j) * local_size] = 1.
+            else:
+                background[i * local_size:(1 + i) * local_size,
+                j * local_size:(1 + j) * local_size] = local_data.mean()
+
+    background = median_filter(background, size=(local_size,local_size))
+            # alpha = 1.25
+    # Starblade = sb.build_starblade(data, alpha=alpha)
+    # for i in range(10):
+    #     Starblade = sb.starblade_iteration(Starblade)
+    #
+    #     plotting on logarithmic scale
+    # background += background.min()
+    plt.gray()
+    plt.imsave('diffuse_component.png', np.log(background))#, vmin=vmin, vmax=vmax)
+    plt.imsave('pointlike_component.png', (data_true - background), vmin=vmin, vmax=vmax)
+    plt.imsave('crowded.png',crowded)
+        # plt.figure()
+        # k_lenghts = Starblade.power_spectrum.domain[0].k_lengths
+        # plt.plot(k_lenghts, Starblade.power_spectrum.val)
+        # plt.title('power spectrum')
+        # plt.yscale('log')
+        # plt.xscale('log')
+        # plt.ylabel('power')
+        # plt.xscale('harmonic mode')
+        # plt.savefig('power_spectrum.png')
diff --git a/demos/demo.py b/demos/demo.py
index 1f850b7..ac3db51 100644
--- a/demos/demo.py
+++ b/demos/demo.py
@@ -25,18 +25,26 @@ import starblade as sb
 if __name__ == '__main__':
     #specifying location of the input file:
     path = 'data/hst_05195_01_wfpc2_f702w_pc_sci.fits'
-    data = fits.open(path)[1].data
+    path = 'data/frame-i-004874-3-0692.fits'
+
+    # data = fits.open(path)[1].data
+    data = fits.open(path)[0].data[1000:15000,1250:1750]
+    data -= data.min() - 0.001
+    # data = 1.-plt.imread('data/sdss.png').T[0]
+    # data = fits.open(path)[1].data
+
+    data = data.clip(min=0.0001)
 
-    data = data.clip(min=0.001)
 
     data = np.ndarray.astype(data, float)
-    vmin = np.log(data.min()+0.01)
+    vmin = np.log(data.min()+0.2)
     vmax = np.log(data.max())
+    plt.imsave('data.png', np.log(data),vmin=vmin,vmax=vmax)
 
     alpha = 1.25
     Starblade = sb.build_starblade(data, alpha=alpha)
     for i in range(10):
-        Starblade = sb.starblade_iteration(Starblade)
+        Starblade = sb.starblade_iteration(Starblade, samples=i)
 
         #plotting on logarithmic scale
         plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax)
@@ -48,5 +56,5 @@ if __name__ == '__main__':
         plt.yscale('log')
         plt.xscale('log')
         plt.ylabel('power')
-        plt.xscale('harmonic mode')
+        plt.xlabel('harmonic mode')
         plt.savefig('power_spectrum.png')
diff --git a/starblade/__init__.py b/starblade/__init__.py
index f86f342..8fff660 100644
--- a/starblade/__init__.py
+++ b/starblade/__init__.py
@@ -1,2 +1,4 @@
 from .sugar import (build_starblade, starblade_iteration,
                     build_multi_starblade, multi_starblade_iteration)
+from .starblade_kl import StarbladeKL
+from .starblade_energy import StarbladeEnergy
diff --git a/starblade/starblade_kl.py b/starblade/starblade_kl.py
new file mode 100644
index 0000000..8d44fd6
--- /dev/null
+++ b/starblade/starblade_kl.py
@@ -0,0 +1,60 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2017-2018 Max-Planck-Society
+# Author: Jakob Knollmueller
+#
+# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
+
+from nifty4 import Energy, Field,  DiagonalOperator, InversionEnabler
+from starblade_energy import StarbladeEnergy
+
+class StarbladeKL(Energy):
+
+    def __init__(self, position, samples, parameters):
+        super(StarbladeKL, self).__init__(position=position)
+        self.samples = samples
+        self.parameters = parameters
+        self.energy_list=[]
+        for sample in samples:
+            energy = StarbladeEnergy(position+sample,parameters)
+            self.energy_list.append(energy)
+
+
+    def at(self, position):
+        return self.__class__(position, samples=self.samples, parameters=self.parameters)
+
+    @property
+    def value(self):
+        value = 0.
+        for energy in self.energy_list:
+            value += energy.value
+        value /= len(self.energy_list)
+        return value
+
+    @property
+    def gradient(self):
+        gradient = Field.zeros(self.position.domain)
+        for energy in self.energy_list:
+            gradient += energy.gradient
+        gradient /= len(self.energy_list)
+        return gradient
+
+    @property
+    def curvature(self):
+        curvature = DiagonalOperator(Field.zeros(self.position.domain))
+        for energy in self.energy_list:
+            curvature += energy.curvature
+        curvature *= Field(self.position.domain,val=1./len(self.energy_list))
+        return InversionEnabler(curvature, self.parameters['inverter'])
+
diff --git a/starblade/sugar.py b/starblade/sugar.py
index 7a235c8..f77ccaa 100644
--- a/starblade/sugar.py
+++ b/starblade/sugar.py
@@ -21,9 +21,9 @@ from multiprocessing import Pool
 import nifty4 as ift
 
 from .starblade_energy import StarbladeEnergy
+from .starblade_kl import StarbladeKL
 
-
-def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iterations = 3,
+def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=100, newton_iterations = 3,
                     manual_power_spectrum = None):
     """ Setting up the StarbladeEnergy for the given data and parameters
     Parameters
@@ -69,9 +69,12 @@ def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500, newton_iteratio
                       inverter=inverter, FFT=FFT,
                       newton_iterations=newton_iterations, update_power=update_power)
     Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
+
+
     return Starblade
 
-def starblade_iteration(starblade):
+
+def starblade_iteration(starblade, samples=3):
     """ Performing one Newton minimization step
     Parameters
     ----------
@@ -82,14 +85,19 @@ def starblade_iteration(starblade):
                                             tol_abs_gradnorm=1e-8,
                                             iteration_limit=starblade.newton_iterations)
     minimizer = ift.RelaxedNewton(controller=controller)
-    energy, convergence = minimizer(starblade)
+    sample_list = []
+    for i in range(samples):
+        sample = starblade.curvature.inverse.draw_sample()
+        sample_list.append(sample)
+    if len(sample_list)>0:
+        energy = StarbladeKL(starblade.position, samples=sample_list, parameters=starblade.parameters)
+    else:
+        energy = starblade
+    energy, convergence = minimizer(energy)
     new_position = energy.position
     new_parameters = energy.parameters
-    if energy.update_power:
-        h_space = energy.correlation.domain[0]
-        FFT = energy.FFT
-        binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
-        new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
+    if energy.parameters['update_power']:
+        new_power = update_power(energy)
         # new_power /= (new_power.domain[0].k_lengths+1.)**2
         new_parameters['power_spectrum'] = new_power
 
@@ -143,6 +151,18 @@ def multi_starblade_iteration(MultiStarblade,  processes = 1):
             NewStarblades.append(starblade_iteration(starblade))
     return NewStarblades
 
+def update_power(energy):
+    if isinstance(energy, StarbladeKL):
+        power = 0.
+        for en in energy.energy_list:
+            power += ift.power_analyze(energy.parameters['FFT'].inverse_times(en.s),
+                                                 binbounds=en.parameters['power_spectrum'].domain[0].binbounds)
+        power /= len(energy.energy_list)
+    else:
+        power = ift.power_analyze(energy.FFT.inverse_times(energy.s),
+                                   binbounds=energy.parameters['power_spectrum'].domain[0].binbounds)
+    return power
+
 if __name__ == '__main__':
     pass
 
-- 
GitLab