From 35e632b032eec6cdb1ad04f2effd052cbb29b441 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Sun, 4 Feb 2018 16:10:59 +0100 Subject: [PATCH] make the 1D part (sort of) work with NIFTy4 --- demos/Wiener Filter.ipynb | 187 ++++++++++++++++---------------------- 1 file changed, 80 insertions(+), 107 deletions(-) diff --git a/demos/Wiener Filter.ipynb b/demos/Wiener Filter.ipynb index eddcb03db..0ad954caa 100644 --- a/demos/Wiener Filter.ipynb +++ b/demos/Wiener Filter.ipynb @@ -93,7 +93,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "-" } @@ -101,7 +100,6 @@ "outputs": [], "source": [ "N_pixels = 512 # Number of pixels\n", - "sigma2 = .5 # Noise variance\n", "\n", "def pow_spec(k):\n", " P0, k0, gamma = [.2, 5, 6]\n", @@ -140,11 +138,11 @@ }, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from nifty import (DiagonalOperator, EndomorphicOperator, FFTOperator, Field,\n", - " InvertibleOperatorMixin, PowerSpace, RGSpace,\n", - " create_power_operator, SmoothingOperator, DiagonalProberMixin, Prober)" + "np.random.seed(42)\n", + "import nifty4 as ift\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" ] }, { @@ -162,43 +160,20 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ - "class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):\n", - " def __init__(self, R, N, Sh, default_spaces=None):\n", - " super(PropagatorOperator, self).__init__(default_spaces=default_spaces,\n", - " preconditioner=lambda x : fft.adjoint_times(Sh.times(fft.times(x))))\n", - "\n", - " self.R = R\n", - " self.N = N\n", - " self.Sh = Sh\n", - " self._domain = R.domain\n", - " self.fft = FFTOperator(domain=R.domain, target=Sh.domain)\n", - "\n", - " def _inverse_times(self, x, spaces, x0=None):\n", - " return self.R.adjoint_times(self.N.inverse_times(self.R(x))) \\\n", - " + self.fft.adjoint_times(self.Sh.inverse_times(self.fft(x)))\n", - "\n", - " @property\n", - " def domain(self):\n", - " return self._domain\n", - "\n", - " @property\n", - " def unitary(self):\n", - " return False\n", - "\n", - " @property\n", - " def symmetric(self):\n", - " return False\n", - "\n", - " @property\n", - " def self_adjoint(self):\n", - " return True" + "def PropagatorOperator(R, N, Sh):\n", + " IC = ift.GradientNormController(name=\"inverter\", iteration_limit=50000,\n", + " tol_abs_gradnorm=0.1)\n", + " inverter = ift.ConjugateGradient(controller=IC)\n", + " D = (R.adjoint*N.inverse*R + Sh.inverse).inverse\n", + " # MR FIXME: we can/should provide a preconditioner here as well!\n", + " return ift.InversionEnabler(D, inverter)\n", + " #return ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter).inverse\n" ] }, { @@ -247,30 +222,33 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ - "s_space = RGSpace(N_pixels)\n", - "fft = FFTOperator(s_space)\n", - "h_space = fft.target[0]\n", - "p_space = PowerSpace(h_space)\n", - "\n", + "s_space = ift.RGSpace(N_pixels)\n", + "h_space = s_space.get_default_codomain()\n", + "HT = ift.HarmonicTransformOperator(h_space, target=s_space)\n", + "p_space = ift.PowerSpace(h_space)\n", "\n", "# Operators\n", - "Sh = create_power_operator(h_space, power_spectrum=pow_spec)\n", - "N = DiagonalOperator(s_space, diagonal=sigma2, bare=True)\n", - "R = DiagonalOperator(s_space, diagonal=1.)\n", - "D = PropagatorOperator(R=R, N=N, Sh=Sh)\n", + "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n", + "R = HT #*ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)\n", "\n", "# Fields and data\n", - "sh = Field(p_space, val=pow_spec).power_synthesize(real_signal=True)\n", - "s = fft.adjoint_times(sh)\n", - "n = Field.from_random(domain=s_space, random_type='normal',\n", - " std=np.sqrt(sigma2), mean=0)\n", - "d = R(s) + n\n", - "j = R.adjoint_times(N.inverse_times(d))" + "sh = ift.power_synthesize(ift.PS_field(p_space, pow_spec),real_signal=True)\n", + "noiseless_data=R(sh)\n", + "signal_to_noise = 5\n", + "noise_amplitude = noiseless_data.std()/signal_to_noise\n", + "N = ift.ScalingOperator(noise_amplitude**2, s_space)\n", + "\n", + "n = ift.Field.from_random(domain=s_space, random_type='normal',\n", + " std=noise_amplitude, mean=0)\n", + "ift.plot(n)\n", + "d = noiseless_data + n\n", + "ift.plot(d)\n", + "j = R.adjoint_times(N.inverse_times(d))\n", + "ift.plot(HT(j))\n", + "D = PropagatorOperator(R=R, N=N, Sh=Sh)" ] }, { @@ -288,7 +266,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "-" } @@ -313,23 +290,22 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ - "s_power = sh.power_analyze()\n", - "m_power = fft(m).power_analyze()\n", - "s_power_data = s_power.val.get_full_data().real\n", - "m_power_data = m_power.val.get_full_data().real\n", + "s_power = ift.power_analyze(sh)\n", + "m_power = ift.power_analyze(m)\n", + "s_power_data = s_power.val.real\n", + "m_power_data = m_power.val.real\n", "\n", "# Get signal data and reconstruction data\n", - "s_data = s.val.get_full_data().real\n", - "m_data = m.val.get_full_data().real\n", + "s_data = HT(sh).val.real\n", + "m_data = HT(m).val.real\n", "\n", - "d_data = d.val.get_full_data().real" + "d_data = d.val.real" ] }, { @@ -375,7 +351,7 @@ "plt.plot(s_data - s_data, 'k', label=\"Signal\", alpha=.5, linewidth=.5)\n", "plt.plot(d_data - s_data, 'k+', label=\"Data\")\n", "plt.plot(m_data - s_data, 'r', label=\"Reconstruction\")\n", - "plt.axhspan(-np.sqrt(sigma2),np.sqrt(sigma2), facecolor='0.9', alpha=.5)\n", + "plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)\n", "plt.title(\"Residuals\")\n", "plt.legend()\n", "plt.show()" @@ -410,8 +386,8 @@ "plt.plot(xs, pow_spec(xs), label=\"True Power Spectrum\", linewidth=.7, color='k')\n", "plt.plot(s_power_data, 'k', label=\"Signal\", alpha=.5, linewidth=.5)\n", "plt.plot(m_power_data, 'r', label=\"Reconstruction\")\n", - "plt.axhline(sigma2 / N_pixels, color=\"k\", linestyle='--', label=\"Noise level\", alpha=.5)\n", - "plt.axhspan(sigma2 / N_pixels, ymin, facecolor='0.9', alpha=.5)\n", + "plt.axhline(noise_amplitude**2 / N_pixels, color=\"k\", linestyle='--', label=\"Noise level\", alpha=.5)\n", + "plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)\n", "plt.title(\"Power Spectrum\")\n", "plt.legend()\n", "plt.show()" @@ -432,7 +408,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "skip" } @@ -440,15 +415,15 @@ "outputs": [], "source": [ "# Operators\n", - "Sh = create_power_operator(h_space, power_spectrum=pow_spec)\n", - "N = DiagonalOperator(s_space, diagonal=sigma2, bare=True)\n", + "Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)\n", + "N = ift.ScalingOperator(noise_amplitude**2,s_space)\n", "# R is defined below\n", "\n", "# Fields\n", - "sh = Field(p_space, val=pow_spec).power_synthesize(real_signal=True)\n", - "s = fft.adjoint_times(sh)\n", - "n = Field.from_random(domain=s_space, random_type='normal',\n", - " std=np.sqrt(sigma2), mean=0)" + "sh = ift.power_synthesize(ift.PS_field(p_space,pow_spec),real_signal=True)\n", + "s = HT(sh)\n", + "n = ift.Field.from_random(domain=s_space, random_type='normal',\n", + " std=noise_amplitude, mean=0)" ] }, { @@ -466,7 +441,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "-" } @@ -474,22 +448,21 @@ "outputs": [], "source": [ "l = int(N_pixels * 0.2)\n", - "h = int(N_pixels * 0.2 * 4)\n", + "h = int(N_pixels * 0.2 * 2)\n", "\n", - "mask = Field(s_space, val=1)\n", + "mask = ift.Field(s_space, val=1)\n", "mask.val[ l : h] = 0\n", "\n", - "R = DiagonalOperator(s_space, diagonal = mask)\n", + "R = ift.DiagonalOperator(mask)*HT\n", "n.val[l:h] = 0\n", "\n", - "d = R(s) + n" + "d = R(sh) + n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "skip" } @@ -516,17 +489,21 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true + "scrolled": true }, "outputs": [], "source": [ - "class DiagonalProber(DiagonalProberMixin, Prober):\n", - " def __init__(self, *args, **kwargs):\n", - " super(DiagonalProber,self).__init__(*args, **kwargs)\n", + "sc = ift.probing.utils.StatCalculator()\n", "\n", - "diagProber = DiagonalProber(domain=s_space, probe_dtype=np.complex, probe_count=200)\n", - "diagProber(D)\n", - "m_var = Field(s_space,val=diagProber.diagonal.val).weight(-1)" + "IC = ift.GradientNormController(name=\"inverter\", iteration_limit=50000,\n", + " tol_abs_gradnorm=0.1)\n", + "inverter = ift.ConjugateGradient(controller=IC)\n", + "curv = ift.library.wiener_filter_curvature.WienerFilterCurvature(R,N,Sh,inverter)\n", + "\n", + "for i in range(200):\n", + " sc.add(HT(curv.generate_posterior_sample()))\n", + "\n", + "m_var = sc.var" ] }, { @@ -544,25 +521,24 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ - "s_power = sh.power_analyze()\n", - "m_power = fft(m).power_analyze()\n", - "s_power_data = s_power.val.get_full_data().real\n", - "m_power_data = m_power.val.get_full_data().real\n", + "s_power = ift.power_analyze(sh)\n", + "m_power = ift.power_analyze(m)\n", + "s_power_data = s_power.val.real\n", + "m_power_data = m_power.val.real\n", "\n", "# Get signal data and reconstruction data\n", - "s_data = s.val.get_full_data().real\n", - "m_data = m.val.get_full_data().real\n", - "m_var_data = m_var.val.get_full_data().real\n", + "s_data = s.val.real\n", + "m_data = HT(m).val.real\n", + "m_var_data = m_var.val.real\n", "uncertainty = np.sqrt(np.abs(m_var_data))\n", - "\n", - "d_data = d.val.get_full_data().real\n", + "ift.plot(ift.sqrt(m_var))\n", + "d_data = d.val.real\n", "\n", "# Set lost data to NaN for proper plotting\n", "d_data[d_data == 0] = np.nan" @@ -646,9 +622,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "N_pixels = 256 # Number of pixels\n", @@ -660,14 +634,13 @@ " return P0 * (1. + (k/k0)**2)**(- gamma / 2)\n", "\n", "\n", - "s_space = RGSpace([N_pixels, N_pixels])" + "s_space = ift.RGSpace([N_pixels, N_pixels])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "slideshow": { "slide_type": "skip" } @@ -857,21 +830,21 @@ "metadata": { "celltoolbar": "Slideshow", "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 2", "language": "python", - "name": "python3" + "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" + "pygments_lexer": "ipython2", + "version": "2.7.12" } }, "nbformat": 4, -- GitLab