From 35e632b032eec6cdb1ad04f2effd052cbb29b441 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sun, 4 Feb 2018 16:10:59 +0100
Subject: [PATCH] make the 1D part (sort of) work with NIFTy4

---
 demos/Wiener Filter.ipynb | 187 ++++++++++++++++----------------------
 1 file changed, 80 insertions(+), 107 deletions(-)

diff --git a/demos/Wiener Filter.ipynb b/demos/Wiener Filter.ipynb
index eddcb03db..0ad954caa 100644
--- a/demos/Wiener Filter.ipynb	
+++ b/demos/Wiener Filter.ipynb	
@@ -93,7 +93,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "-"
     }
@@ -101,7 +100,6 @@
    "outputs": [],
    "source": [
     "N_pixels = 512     # Number of pixels\n",
-    "sigma2 = .5        # Noise variance\n",
     "\n",
     "def pow_spec(k):\n",
     "    P0, k0, gamma = [.2, 5, 6]\n",
@@ -140,11 +138,11 @@
    },
    "outputs": [],
    "source": [
-    "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "from nifty import (DiagonalOperator, EndomorphicOperator, FFTOperator, Field,\n",
-    "                   InvertibleOperatorMixin, PowerSpace, RGSpace,\n",
-    "                   create_power_operator, SmoothingOperator, DiagonalProberMixin, Prober)"
+    "np.random.seed(42)\n",
+    "import nifty4 as ift\n",
+    "import matplotlib.pyplot as plt\n",
+    "%matplotlib inline"
    ]
   },
   {
@@ -162,43 +160,20 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "-"
     }
    },
    "outputs": [],
    "source": [
-    "class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):\n",
-    "    def __init__(self, R, N, Sh, default_spaces=None):\n",
-    "        super(PropagatorOperator, self).__init__(default_spaces=default_spaces,\n",
-    "                                                 preconditioner=lambda x : fft.adjoint_times(Sh.times(fft.times(x))))\n",
-    "\n",
-    "        self.R = R\n",
-    "        self.N = N\n",
-    "        self.Sh = Sh\n",
-    "        self._domain = R.domain\n",
-    "        self.fft = FFTOperator(domain=R.domain, target=Sh.domain)\n",
-    "\n",
-    "    def _inverse_times(self, x, spaces, x0=None):\n",
-    "        return self.R.adjoint_times(self.N.inverse_times(self.R(x))) \\\n",
-    "               + self.fft.adjoint_times(self.Sh.inverse_times(self.fft(x)))\n",
-    "\n",
-    "    @property\n",
-    "    def domain(self):\n",
-    "        return self._domain\n",
-    "\n",
-    "    @property\n",
-    "    def unitary(self):\n",
-    "        return False\n",
-    "\n",
-    "    @property\n",
-    "    def symmetric(self):\n",
-    "        return False\n",
-    "\n",
-    "    @property\n",
-    "    def self_adjoint(self):\n",
-    "        return True"
+    "def PropagatorOperator(R, N, Sh):\n",
+    "    IC = ift.GradientNormController(name=\"inverter\", iteration_limit=50000,\n",
+    "                                    tol_abs_gradnorm=0.1)\n",
+    "    inverter = ift.ConjugateGradient(controller=IC)\n",
+    "    D = (R.adjoint*N.inverse*R + Sh.inverse).inverse\n",
+    "    # MR FIXME: we can/should provide a preconditioner here as well!\n",
+    "    return ift.InversionEnabler(D, inverter)\n",
+    "    #return ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter).inverse\n"
    ]
   },
   {
@@ -247,30 +222,33 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "s_space = RGSpace(N_pixels)\n",
-    "fft = FFTOperator(s_space)\n",
-    "h_space = fft.target[0]\n",
-    "p_space = PowerSpace(h_space)\n",
-    "\n",
+    "s_space = ift.RGSpace(N_pixels)\n",
+    "h_space = s_space.get_default_codomain()\n",
+    "HT = ift.HarmonicTransformOperator(h_space, target=s_space)\n",
+    "p_space = ift.PowerSpace(h_space)\n",
     "\n",
     "# Operators\n",
-    "Sh = create_power_operator(h_space, power_spectrum=pow_spec)\n",
-    "N = DiagonalOperator(s_space, diagonal=sigma2, bare=True)\n",
-    "R = DiagonalOperator(s_space, diagonal=1.)\n",
-    "D = PropagatorOperator(R=R, N=N, Sh=Sh)\n",
+    "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n",
+    "R = HT #*ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)\n",
     "\n",
     "# Fields and data\n",
-    "sh = Field(p_space, val=pow_spec).power_synthesize(real_signal=True)\n",
-    "s = fft.adjoint_times(sh)\n",
-    "n = Field.from_random(domain=s_space, random_type='normal',\n",
-    "                      std=np.sqrt(sigma2), mean=0)\n",
-    "d = R(s) + n\n",
-    "j = R.adjoint_times(N.inverse_times(d))"
+    "sh = ift.power_synthesize(ift.PS_field(p_space, pow_spec),real_signal=True)\n",
+    "noiseless_data=R(sh)\n",
+    "signal_to_noise = 5\n",
+    "noise_amplitude = noiseless_data.std()/signal_to_noise\n",
+    "N = ift.ScalingOperator(noise_amplitude**2, s_space)\n",
+    "\n",
+    "n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
+    "                      std=noise_amplitude, mean=0)\n",
+    "ift.plot(n)\n",
+    "d = noiseless_data + n\n",
+    "ift.plot(d)\n",
+    "j = R.adjoint_times(N.inverse_times(d))\n",
+    "ift.plot(HT(j))\n",
+    "D = PropagatorOperator(R=R, N=N, Sh=Sh)"
    ]
   },
   {
@@ -288,7 +266,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "-"
     }
@@ -313,23 +290,22 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "-"
     }
    },
    "outputs": [],
    "source": [
-    "s_power = sh.power_analyze()\n",
-    "m_power = fft(m).power_analyze()\n",
-    "s_power_data = s_power.val.get_full_data().real\n",
-    "m_power_data = m_power.val.get_full_data().real\n",
+    "s_power = ift.power_analyze(sh)\n",
+    "m_power = ift.power_analyze(m)\n",
+    "s_power_data = s_power.val.real\n",
+    "m_power_data = m_power.val.real\n",
     "\n",
     "# Get signal data and reconstruction data\n",
-    "s_data = s.val.get_full_data().real\n",
-    "m_data = m.val.get_full_data().real\n",
+    "s_data = HT(sh).val.real\n",
+    "m_data = HT(m).val.real\n",
     "\n",
-    "d_data = d.val.get_full_data().real"
+    "d_data = d.val.real"
    ]
   },
   {
@@ -375,7 +351,7 @@
     "plt.plot(s_data - s_data, 'k', label=\"Signal\", alpha=.5, linewidth=.5)\n",
     "plt.plot(d_data - s_data, 'k+', label=\"Data\")\n",
     "plt.plot(m_data - s_data, 'r', label=\"Reconstruction\")\n",
-    "plt.axhspan(-np.sqrt(sigma2),np.sqrt(sigma2), facecolor='0.9', alpha=.5)\n",
+    "plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)\n",
     "plt.title(\"Residuals\")\n",
     "plt.legend()\n",
     "plt.show()"
@@ -410,8 +386,8 @@
     "plt.plot(xs, pow_spec(xs), label=\"True Power Spectrum\", linewidth=.7, color='k')\n",
     "plt.plot(s_power_data, 'k', label=\"Signal\", alpha=.5, linewidth=.5)\n",
     "plt.plot(m_power_data, 'r', label=\"Reconstruction\")\n",
-    "plt.axhline(sigma2 / N_pixels, color=\"k\", linestyle='--', label=\"Noise level\", alpha=.5)\n",
-    "plt.axhspan(sigma2 / N_pixels, ymin, facecolor='0.9', alpha=.5)\n",
+    "plt.axhline(noise_amplitude**2 / N_pixels, color=\"k\", linestyle='--', label=\"Noise level\", alpha=.5)\n",
+    "plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)\n",
     "plt.title(\"Power Spectrum\")\n",
     "plt.legend()\n",
     "plt.show()"
@@ -432,7 +408,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "skip"
     }
@@ -440,15 +415,15 @@
    "outputs": [],
    "source": [
     "# Operators\n",
-    "Sh = create_power_operator(h_space, power_spectrum=pow_spec)\n",
-    "N = DiagonalOperator(s_space, diagonal=sigma2, bare=True)\n",
+    "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n",
+    "N = ift.ScalingOperator(noise_amplitude**2,s_space)\n",
     "# R is defined below\n",
     "\n",
     "# Fields\n",
-    "sh = Field(p_space, val=pow_spec).power_synthesize(real_signal=True)\n",
-    "s = fft.adjoint_times(sh)\n",
-    "n = Field.from_random(domain=s_space, random_type='normal',\n",
-    "                      std=np.sqrt(sigma2), mean=0)"
+    "sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True)\n",
+    "s = HT(sh)\n",
+    "n = ift.Field.from_random(domain=s_space, random_type='normal',\n",
+    "                      std=noise_amplitude, mean=0)"
    ]
   },
   {
@@ -466,7 +441,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "-"
     }
@@ -474,22 +448,21 @@
    "outputs": [],
    "source": [
     "l = int(N_pixels * 0.2)\n",
-    "h = int(N_pixels * 0.2 * 4)\n",
+    "h = int(N_pixels * 0.2 * 2)\n",
     "\n",
-    "mask = Field(s_space, val=1)\n",
+    "mask = ift.Field(s_space, val=1)\n",
     "mask.val[ l : h] = 0\n",
     "\n",
-    "R = DiagonalOperator(s_space, diagonal = mask)\n",
+    "R = ift.DiagonalOperator(mask)*HT\n",
     "n.val[l:h] = 0\n",
     "\n",
-    "d = R(s) + n"
+    "d = R(sh) + n"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "skip"
     }
@@ -516,17 +489,21 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true
+    "scrolled": true
    },
    "outputs": [],
    "source": [
-    "class DiagonalProber(DiagonalProberMixin, Prober):\n",
-    "    def __init__(self, *args, **kwargs):\n",
-    "        super(DiagonalProber,self).__init__(*args, **kwargs)\n",
+    "sc = ift.probing.utils.StatCalculator()\n",
     "\n",
-    "diagProber = DiagonalProber(domain=s_space, probe_dtype=np.complex, probe_count=200)\n",
-    "diagProber(D)\n",
-    "m_var = Field(s_space,val=diagProber.diagonal.val).weight(-1)"
+    "IC = ift.GradientNormController(name=\"inverter\", iteration_limit=50000,\n",
+    "                                    tol_abs_gradnorm=0.1)\n",
+    "inverter = ift.ConjugateGradient(controller=IC)\n",
+    "curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)\n",
+    "\n",
+    "for i in range(200):\n",
+    "    sc.add(HT(curv.generate_posterior_sample()))\n",
+    "\n",
+    "m_var = sc.var"
    ]
   },
   {
@@ -544,25 +521,24 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "skip"
     }
    },
    "outputs": [],
    "source": [
-    "s_power = sh.power_analyze()\n",
-    "m_power = fft(m).power_analyze()\n",
-    "s_power_data = s_power.val.get_full_data().real\n",
-    "m_power_data = m_power.val.get_full_data().real\n",
+    "s_power = ift.power_analyze(sh)\n",
+    "m_power = ift.power_analyze(m)\n",
+    "s_power_data = s_power.val.real\n",
+    "m_power_data = m_power.val.real\n",
     "\n",
     "# Get signal data and reconstruction data\n",
-    "s_data = s.val.get_full_data().real\n",
-    "m_data = m.val.get_full_data().real\n",
-    "m_var_data = m_var.val.get_full_data().real\n",
+    "s_data = s.val.real\n",
+    "m_data = HT(m).val.real\n",
+    "m_var_data = m_var.val.real\n",
     "uncertainty = np.sqrt(np.abs(m_var_data))\n",
-    "\n",
-    "d_data = d.val.get_full_data().real\n",
+    "ift.plot(ift.sqrt(m_var))\n",
+    "d_data = d.val.real\n",
     "\n",
     "# Set lost data to NaN for proper plotting\n",
     "d_data[d_data == 0] = np.nan"
@@ -646,9 +622,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "N_pixels = 256      # Number of pixels\n",
@@ -660,14 +634,13 @@
     "    return P0 * (1. + (k/k0)**2)**(- gamma / 2)\n",
     "\n",
     "\n",
-    "s_space = RGSpace([N_pixels, N_pixels])"
+    "s_space = ift.RGSpace([N_pixels, N_pixels])"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "slideshow": {
      "slide_type": "skip"
     }
@@ -857,21 +830,21 @@
  "metadata": {
   "celltoolbar": "Slideshow",
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 2",
    "language": "python",
-   "name": "python3"
+   "name": "python2"
   },
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version": 3
+    "version": 2
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.6.4"
+   "pygments_lexer": "ipython2",
+   "version": "2.7.12"
   }
  },
  "nbformat": 4,
-- 
GitLab