From c8ff706777c26676ff86989081b24cdfbf04aac1 Mon Sep 17 00:00:00 2001
From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de>
Date: Tue, 10 Apr 2018 14:32:29 +0200
Subject: [PATCH] documentation, cleanup

---
 1d_separation.py     |   2 +-
 demo.py              |  25 ++++++++++
 gui_app.py           |   2 +-
 hubble_separation.py |  12 +++--
 multichannel_demo.py |  28 +++++++++++
 point_separation.py  |  74 -----------------------------
 rgb_separation.py    |  25 ----------
 separation_energy.py |  32 +++++++++++--
 sugar.py             | 111 +++++++++++++++++++++++++++++++++++++++++++
 9 files changed, 203 insertions(+), 108 deletions(-)
 create mode 100644 demo.py
 create mode 100644 multichannel_demo.py
 delete mode 100644 point_separation.py
 delete mode 100644 rgb_separation.py
 create mode 100644 sugar.py

diff --git a/1d_separation.py b/1d_separation.py
index f90f0dd..653b090 100644
--- a/1d_separation.py
+++ b/1d_separation.py
@@ -1,4 +1,4 @@
-from point_separation import build_problem, problem_iteration
+from sugar import build_problem, problem_iteration
 import nifty4 as ift
 import numpy as np
 from matplotlib import rc
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000..1adbbb6
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,25 @@
+from sugar import build_starblade, starblade_iteration
+from matplotlib import pyplot as plt
+from astropy.io import fits
+
+import numpy as np
+
+if __name__ == '__main__':
+    #specifying location of the input file:
+    path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
+    data = fits.open(path)[1].data
+
+    data = data.clip(min=0.001)
+
+    data = np.ndarray.astype(data, float)
+    vmin = np.log(data.min()+0.01)
+    vmax = np.log(data.max())
+
+    alpha = 1.25
+    Starblade = build_starblade(data, alpha=alpha)
+    for i in range(10):
+        Starblade = starblade_iteration(Starblade)
+
+        #plotting on logarithmic scale
+        plt.imsave('diffuse_component.png', Starblade.s.val, vmin=vmin, vmax=vmax)
+        plt.imsave('pointlike_component.png', Starblade.u.val, vmin=vmin, vmax=vmax)
\ No newline at end of file
diff --git a/gui_app.py b/gui_app.py
index 6972990..61aa4bf 100644
--- a/gui_app.py
+++ b/gui_app.py
@@ -2,7 +2,7 @@ import matplotlib
 matplotlib.use('agg')
 # matplotlib.use('module://kivy.garden.matplotlib.backend_kivy')
 
-from point_separation import  build_multi_problem, multi_problem_iteration,load_data
+from sugar import  build_multi_problem, multi_problem_iteration,load_data
 
 from kivy.app import App
 from kivy.uix.widget import Widget
diff --git a/hubble_separation.py b/hubble_separation.py
index 7d3c435..f31c09f 100644
--- a/hubble_separation.py
+++ b/hubble_separation.py
@@ -1,4 +1,4 @@
-from point_separation import build_problem, problem_iteration, load_data
+from sugar import build_problem, problem_iteration, load_data
 from nifty4 import *
 import numpy as np
 from matplotlib import rc
@@ -8,14 +8,20 @@ from matplotlib import pyplot as plt
 from matplotlib.colors import LogNorm
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 from mpl_toolkits.axes_grid1 import AxesGrid
+from astropy.io import fits
+
 
 
 
 np.random.seed(42)
 if __name__ == '__main__':
     path = 'hst_05195_01_wfpc2_f702w_pc_sci.fits'
-    data = load_data(path)
-    alpha = 1.3
+    data = fits.open(path)[1].data
+
+    data = data.clip(min=0.001)
+
+    data = np.ndarray.astype(data, float)
+    alpha = 1.25
 
 
 
diff --git a/multichannel_demo.py b/multichannel_demo.py
new file mode 100644
index 0000000..52ffcd8
--- /dev/null
+++ b/multichannel_demo.py
@@ -0,0 +1,28 @@
+from sugar import build_multi_starblade, multi_starblade_iteration
+from matplotlib import pyplot as plt
+import numpy as np
+
+if __name__ == '__main__':
+
+    # data = plt.imread('10Keso1242a.tif')
+    data = plt.imread('eso1242a.jpg')
+
+    data = data.astype(float)
+    data = data.clip(0.0001)
+    alpha = 1.25
+    MultiStarblade = build_multi_starblade(data, alpha)
+
+    for i in range(2):
+        MultiStarblade = multi_starblade_iteration(MultiStarblade, multiprocessing=True)
+
+        #plotting a three channel RGB image in each iteration
+        diffuse = np.empty_like(data)
+        point = np.empty_like(data)
+        for i  in range(len(MultiStarblade)):
+            diffuse[...,i] = np.exp(MultiStarblade[i].s.val)
+            point[...,i] = np.exp(MultiStarblade[i].u.val)
+
+        plt.imsave('rgb_diffuse.jpg',diffuse/255.)
+        plt.imsave('rgb_point.jpg',point/255.)
+
+
diff --git a/point_separation.py b/point_separation.py
deleted file mode 100644
index 0746292..0000000
--- a/point_separation.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import nifty4 as ift
-import numpy as np
-from matplotlib import pyplot as plt
-from astropy.io import fits
-from separation_energy import SeparationEnergy
-from nifty4.library.nonlinearities import PositiveTanh
-
-def load_data(path):
-
-    if path[-5:] == '.fits':
-        data = fits.open(path)[1].data
-    else:
-        data = plt.imread(path)[:,:,0]
-
-    data = data.clip(min=0.001)
-    data = np.ndarray.astype(data, float)
-    return data
-
-def build_problem(data, alpha):
-    s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
-    h_space = s_space.get_default_codomain()
-    data = ift.Field(s_space,val=data)
-    FFT = ift.FFTOperator(h_space)
-    binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
-    p_space = ift.PowerSpace(h_space, binbounds=binbounds)
-    initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), binbounds=p_space.binbounds)
-    initial_correlation = ift.create_power_operator(h_space, initial_spectrum)
-    initial_x = ift.Field(s_space, val=-1.)
-    alpha = ift.Field(s_space, val=alpha)
-    q = ift.Field(s_space, val=10e-40)
-    pos_tanh = PositiveTanh()
-    ICI = ift.GradientNormController(iteration_limit=500,
-                                     tol_abs_gradnorm=1e-5)
-    inverter = ift.ConjugateGradient(controller=ICI)
-
-    parameters = dict(data=data, correlation=initial_correlation,
-                      alpha=alpha, q=q,
-                      inverter=inverter, FFT=FFT, pos_tanh=pos_tanh)
-    separationEnergy = SeparationEnergy(position=initial_x, parameters=parameters)
-    return separationEnergy
-
-def problem_iteration(energy, iterations=3):
-    controller = ift.GradientNormController(name="test1", tol_abs_gradnorm=0.00000001, iteration_limit=iterations)
-    minimizer = ift.RelaxedNewton(controller=controller)
-    energy, convergence = minimizer(energy)
-    new_position = energy.position
-    h_space = energy.correlation.domain[0]
-    FFT = energy.FFT
-    binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
-    new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
-    new_correlation = ift.create_power_operator(h_space, new_power)
-    new_parameters = energy.parameters
-    new_parameters['correlation'] = new_correlation
-    new_energy = SeparationEnergy(new_position, new_parameters)
-    return new_energy
-
-def build_multi_problem(data, alpha):
-    energy_list = []
-    for i in range(data.shape[-1]):
-        energy = build_problem(data[...,i],alpha)
-        energy_list.append(energy)
-    return energy_list
-
-def multi_problem_iteration(energy_list):
-    new_energy = []
-    for energy in energy_list:
-        new_energy.append(problem_iteration(energy))
-    return new_energy
-
-if __name__ == '__main__':
-    pass
-
-
-
diff --git a/rgb_separation.py b/rgb_separation.py
deleted file mode 100644
index 2eb14e1..0000000
--- a/rgb_separation.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from point_separation import build_multi_problem, multi_problem_iteration
-from matplotlib import pyplot as plt
-import numpy as np
-
-if __name__ == '__main__':
-    # data = plt.imread('eso1242a.jpg')
-    data = plt.imread('10Keso1242a.tif')
-    data = data.astype(float)
-    data = data.clip(0.0001)
-    energy_list = build_multi_problem(data, 1.2)
-
-    for i in range(10):
-        energy_list = multi_problem_iteration(energy_list)
-
-
-        diffuse = np.empty_like(data)
-        point = np.empty_like(data)
-        for i  in range(len(energy_list)):
-            diffuse[...,i] = np.exp(energy_list[i].s.val)
-            point[...,i] = np.exp(energy_list[i].u.val)
-
-        plt.imsave('rgb_diffuse.jpg',diffuse/255.)
-        plt.imsave('rgb_point.jpg',point/255.)
-
-
diff --git a/separation_energy.py b/separation_energy.py
index 76d7810..a1ca6ca 100644
--- a/separation_energy.py
+++ b/separation_energy.py
@@ -1,14 +1,39 @@
 from nifty4 import Energy, Field, log, exp, DiagonalOperator
 from nifty4.library import WienerFilterCurvature
+from nifty4.library.nonlinearities import PositiveTanh
 
 
-class SeparationEnergy(Energy):
+class StarbladeEnergy(Energy):
+    """The Energy for the starblade problem.
+
+    It implements the Information Hamiltonian of the separation of d
+
+    Parameters
+    ----------
+    position : Field
+        The current position of the separation.
+    parameters : Dictionary
+        Dictionary containing all relevant quantities for the inference,
+        data : Field
+            The image data.
+        alpha : Field
+            Slope parameter of the point-source prior
+        q : Field
+            Cutoff parameter of the point-source prior
+        correlation : Field
+            A field in the Fourier space which encodes the diagonal of the prior
+            correlation structure of the diffuse component
+        FFT : FFTOperator
+            An operator performing the Fourier transform
+        inverter : ConjugateGradient
+            the minimization strategy to use for operator inversion
+    """
 
     def __init__(self, position, parameters):
 
         x = position.val.clip(-9, 9)
         position = Field(position.domain, val=x)
-        super(SeparationEnergy, self).__init__(position=position)
+        super(StarbladeEnergy, self).__init__(position=position)
 
         self.parameters = parameters
         self.inverter = parameters['inverter']
@@ -17,8 +42,7 @@ class SeparationEnergy(Energy):
         self.correlation = parameters['correlation']
         self.alpha = parameters['alpha']
         self.q = parameters['q']
-        pos_tanh = parameters['pos_tanh']
-
+        pos_tanh = PositiveTanh()
         self.S = self.FFT * self.correlation * self.FFT.adjoint
         self.a = pos_tanh(self.position)
         self.a_p = pos_tanh.derivative(self.position)
diff --git a/sugar.py b/sugar.py
new file mode 100644
index 0000000..3c758c2
--- /dev/null
+++ b/sugar.py
@@ -0,0 +1,111 @@
+import nifty4 as ift
+import numpy as np
+from matplotlib import pyplot as plt
+from multiprocessing import Pool
+from separation_energy import StarbladeEnergy
+
+
+
+def build_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
+    """ Setting up the StarbladeEnergy for the given data and parameters
+    Parameters
+    ----------
+    data : array
+        The data in a numpy array
+    alpha : float
+        The slope parameter of the point source prior (default: 1.5).
+    q : float
+        The cutoff parameter of the point source prior (default: 1e-40).
+    cg_iterations : int
+        Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
+    """
+
+    s_space = ift.RGSpace(data.shape, distances=len(data.shape) * [1])
+    h_space = s_space.get_default_codomain()
+    data = ift.Field(s_space,val=data)
+    FFT = ift.FFTOperator(h_space, target=s_space)
+    binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic = False)
+    p_space = ift.PowerSpace(h_space, binbounds=binbounds)
+    initial_spectrum = ift.power_analyze(FFT.inverse_times(ift.log(data)), binbounds=p_space.binbounds)
+    initial_spectrum /= (p_space.k_lengths+1.)**2
+    initial_correlation = ift.create_power_operator(h_space, initial_spectrum)
+    initial_x = ift.Field(s_space, val=-1.)
+    alpha = ift.Field(s_space, val=alpha)
+    q = ift.Field(s_space, val=q)
+    ICI = ift.GradientNormController(iteration_limit=cg_iterations,
+                                     tol_abs_gradnorm=1e-5)
+    inverter = ift.ConjugateGradient(controller=ICI)
+
+    parameters = dict(data=data, correlation=initial_correlation,
+                      alpha=alpha, q=q,
+                      inverter=inverter, FFT=FFT)
+    Starblade = StarbladeEnergy(position=initial_x, parameters=parameters)
+    return Starblade
+
+def starblade_iteration(starblade, iterations=3):
+    """ Performing one Newton minimization step
+    Parameters
+    ----------
+    starblade : StarbladeEnergy
+        An instance of an Starblade Energy
+    iterations : int
+        The number of steps with the Newton scheme (default: 3).
+    """
+    controller = ift.GradientNormController(name="Newton", tol_abs_gradnorm=1e-8, iteration_limit=iterations)
+    minimizer = ift.RelaxedNewton(controller=controller)
+    energy, convergence = minimizer(starblade)
+    new_position = energy.position
+    h_space = energy.correlation.domain[0]
+    FFT = energy.FFT
+    binbounds = ift.PowerSpace.useful_binbounds(h_space, logarithmic=False)
+    new_power = ift.power_analyze(FFT.inverse_times(energy.s), binbounds=binbounds)
+    # new_power /= (new_power.domain[0].k_lengths+1.)**2
+
+    new_correlation = ift.create_power_operator(h_space, new_power)
+    new_parameters = energy.parameters
+    # new_parameters['correlation'] = new_correlation
+    NewStarblade = StarbladeEnergy(new_position, new_parameters)
+    return NewStarblade
+
+def build_multi_starblade(data, alpha=1.5, q=1e-40, cg_iterations=500):
+    """ Builds a list of StarbladeEnergies for the given multi-channel dataset
+    Parameters
+    ----------
+    data : array
+        The data in a numpy array of the multi-channel dataset with channel axis data[-1].
+    alpha : float
+        The slope parameter of the point source prior (default: 1.5).
+    q : float
+        The cutoff parameter of the point source prior (default: 1e-40).
+    cg_iterations : int
+        Maximum number of conjugate gradient iterations for numerical operator inversion (default: 500).
+    """
+    MultiStarblade = []
+    for i in range(data.shape[-1]):
+        starblade = build_starblade(data[...,i],alpha=alpha, q=q, cg_iterations=cg_iterations)
+        MultiStarblade.append(starblade)
+    return MultiStarblade
+
+def multi_starblade_iteration(MultiStarblade,  multiprocessing = False):
+    """ Performing one Newton minimization step for all entries of the MultiStarblade list.
+    Parameters
+    ----------
+    MultiStarblade : list of StarbladeEnergy
+        A list of instances of an Starblade Energy
+    iterations : int
+        The number of steps with the Newton scheme (default: 3).
+    """
+    if multiprocessing:
+        NewStarblades = list(Pool(processes=3).map(starblade_iteration,
+                                                           MultiStarblade))
+    else:
+        NewStarblades = []
+        for starblade in MultiStarblade:
+            NewStarblades.append(starblade_iteration(starblade))
+    return NewStarblades
+
+if __name__ == '__main__':
+    pass
+
+
+
-- 
GitLab