diff --git a/demos/wiener_filter_via_curvature.py b/demos/wiener_filter_via_curvature.py index 0224333321b1bf1621bc0c10ae3a759d9d0fb0f6..d3b826cd7c026f019f6101755db3cc3d9b77f781 100644 --- a/demos/wiener_filter_via_curvature.py +++ b/demos/wiener_filter_via_curvature.py @@ -13,11 +13,11 @@ if __name__ == "__main__": # Setting up variable parameters # Typical distance over which the field is correlated - correlation_length = 0.01 + correlation_length = 0.05 # Variance of field in position space sqrt(<|s_x|^2>) field_variance = 2. # smoothing length of response (in same unit as L) - response_sigma = 0.1 + response_sigma = 0.01 # The signal to noise ratio signal_to_noise = 0.7 @@ -73,9 +73,10 @@ if __name__ == "__main__": m_s = fft(m) plotter = plotting.RG2DPlotter() - plotter.title = 'mock_signal.html'; + plotter.path = 'mock_signal.html' plotter(mock_signal) - plotter.title = 'data.html' + plotter.path = 'data.html' plotter(Field(signal_space, val=data.val.get_full_data().reshape(signal_space.shape))) - plotter.title = 'map.html'; plotter(m_s) \ No newline at end of file + plotter.path = 'map.html' + plotter(m_s) diff --git a/demos/wiener_filter_via_curvature_real.py b/demos/wiener_filter_via_curvature_real.py new file mode 100644 index 0000000000000000000000000000000000000000..4daed2cf8c8da41a5634b103fc74f18e14e79a20 --- /dev/null +++ b/demos/wiener_filter_via_curvature_real.py @@ -0,0 +1,82 @@ +import numpy as np + +from nifty import RGSpace, PowerSpace, Field, RealFFTOperator,\ + ComposedOperator, DiagonalOperator, ResponseOperator,\ + plotting, create_power_operator +from nifty.library import WienerFilterCurvature + + +if __name__ == "__main__": + + distribution_strategy = 'not' + + # Setting up variable parameters + + # Typical distance over which the field is correlated + correlation_length = 0.05 + # Variance of field in position space sqrt(<|s_x|^2>) + field_variance = 2. + # smoothing length of response (in same unit as L) + response_sigma = 0.01 + # The signal to noise ratio + signal_to_noise = 0.7 + + # note that field_variance**2 = a*k_0/4. for this analytic form of power + # spectrum + def power_spectrum(k): + a = 4 * correlation_length * field_variance**2 + return a / (1 + k * correlation_length) ** 4 + + # Setting up the geometry + + # Total side-length of the domain + L = 2. + # Grid resolution (pixels per axis) + N_pixels = 512 + + signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels) + harmonic_space = RealFFTOperator.get_default_codomain(signal_space) + fft = RealFFTOperator(harmonic_space, target=signal_space) + power_space = PowerSpace(harmonic_space, + distribution_strategy=distribution_strategy) + + # Creating the mock data + S = create_power_operator(harmonic_space, power_spectrum=power_spectrum, + distribution_strategy=distribution_strategy) + + mock_power = Field(power_space, val=power_spectrum, + distribution_strategy=distribution_strategy) + np.random.seed(43) + mock_harmonic = mock_power.power_synthesize(real_signal=True) + mock_harmonic = mock_harmonic.real + mock_harmonic.imag + mock_signal = fft(mock_harmonic) + + R = ResponseOperator(signal_space, sigma=(response_sigma,)) + data_domain = R.target[0] + R_harmonic = ComposedOperator([fft, R], default_spaces=[0, 0]) + + N = DiagonalOperator(data_domain, + diagonal=mock_signal.var()/signal_to_noise, + bare=True) + noise = Field.from_random(domain=data_domain, + random_type='normal', + std=mock_signal.std()/np.sqrt(signal_to_noise), + mean=0) + data = R(mock_signal) + noise + + # Wiener filter + + j = R_harmonic.adjoint_times(N.inverse_times(data)) + wiener_curvature = WienerFilterCurvature(S=S, N=N, R=R_harmonic) + + m = wiener_curvature.inverse_times(j) + m_s = fft(m) + + plotter = plotting.RG2DPlotter() + plotter.path = 'mock_signal.html' + plotter(mock_signal) + plotter.path = 'data.html' + plotter(Field(signal_space, + val=data.val.get_full_data().reshape(signal_space.shape))) + plotter.path = 'map.html' + plotter(m_s) diff --git a/nifty/operators/fft_operator/__init__.py b/nifty/operators/fft_operator/__init__.py index c0247fdece024338baf35fc145f7fb27b0944f7a..53b620e85eddbef6643bc12b5de7ed19fe52cdcf 100644 --- a/nifty/operators/fft_operator/__init__.py +++ b/nifty/operators/fft_operator/__init__.py @@ -17,4 +17,5 @@ # and financially supported by the Studienstiftung des deutschen Volkes. from transformations import * -from fft_operator import FFTOperator +from .fft_operator import FFTOperator +from .real_fft_operator import RealFFTOperator diff --git a/nifty/operators/fft_operator/real_fft_operator.py b/nifty/operators/fft_operator/real_fft_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..e3413c1b875c4d8615393d68d8112d471aeb1fa6 --- /dev/null +++ b/nifty/operators/fft_operator/real_fft_operator.py @@ -0,0 +1,254 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Copyright(C) 2013-2017 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +import numpy as np + +import nifty.nifty_utilities as utilities +from nifty.spaces import RGSpace,\ + GLSpace,\ + HPSpace,\ + LMSpace + +from nifty.operators.linear_operator import LinearOperator +from transformations import RGRGTransformation,\ + LMGLTransformation,\ + LMHPTransformation,\ + GLLMTransformation,\ + HPLMTransformation,\ + TransformationCache + + +class RealFFTOperator(LinearOperator): + """Transforms between a pair of position and harmonic domains. + + Built-in domain pairs are + - a harmonic and a non-harmonic RGSpace (with matching distances) + - a HPSpace and a LMSpace + - a GLSpace and a LMSpace + Within a domain pair, both orderings are possible. + + The operator provides a "times" and an "adjoint_times" operation. + For a pair of RGSpaces, the "adjoint_times" operation is equivalent to + "inverse_times"; for the sphere-related domains this is not the case, since + the operator matrix is not square. + + In contrast to the FFTOperator, RealFFTOperator accepts and returns + real-valued fields only. For the harmonic-space counterpart of a + real-valued field living on an RGSpace, the sum of real + and imaginary components is stored. Since the full complex field has + Hermitian symmetry, this is sufficient to reconstruct the full field + whenever needed (e.g. during the transform back to position space). + + Parameters + ---------- + domain: Space or single-element tuple of Spaces + The domain of the data that is input by "times" and output by + "adjoint_times". + target: Space or single-element tuple of Spaces (optional) + The domain of the data that is output by "times" and input by + "adjoint_times". + If omitted, a co-domain will be chosen automatically. + Whenever "domain" is an RGSpace, the codomain (and its parameters) are + uniquely determined. + For GLSpace, HPSpace, and LMSpace, a sensible (but not unique) + co-domain is chosen that should work satisfactorily in most situations, + but for full control, the user should explicitly specify a codomain. + module: String (optional) + Software module employed for carrying out the transform operations. + For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is + always available (using pyfftw if available, else numpy.fft), and "mpi" + requires pyfftw and offers MPI parallelization. + For sphere-related domains, only "pyHealpix" is + available. If omitted, "fftw" is selected for RGSpaces if available, + else "numpy"; on the sphere the default is "pyHealpix". + + Attributes + ---------- + domain: Tuple of Spaces (with one entry) + The domain of the data that is input by "times" and output by + "adjoint_times". + target: Tuple of Spaces (with one entry) + The domain of the data that is output by "times" and input by + "adjoint_times". + unitary: bool + Returns True if the operator is unitary (currently only the case if + the domain and codomain are RGSpaces), else False. + + Raises + ------ + ValueError: + if "domain" or "target" are not of the proper type. + + """ + default_codomain_dictionary = {RGSpace: RGSpace, + HPSpace: LMSpace, + GLSpace: LMSpace, + LMSpace: GLSpace, + } + + transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation, + (HPSpace, LMSpace): HPLMTransformation, + (GLSpace, LMSpace): GLLMTransformation, + (LMSpace, HPSpace): LMHPTransformation, + (LMSpace, GLSpace): LMGLTransformation + } + + def __init__(self, domain, target=None, module=None, + default_spaces=None): + super(RealFFTOperator, self).__init__(default_spaces) + + # Initialize domain and target + self._domain = self._parse_domain(domain) + if len(self.domain) != 1: + raise ValueError("TransformationOperator accepts only exactly one " + "space as input domain.") + + if target is None: + target = (self.get_default_codomain(self.domain[0]), ) + self._target = self._parse_domain(target) + if len(self.target) != 1: + raise ValueError("TransformationOperator accepts only exactly one " + "space as output target.") + + # Create transformation instances + forward_class = self.transformation_dictionary[ + (self.domain[0].__class__, self.target[0].__class__)] + backward_class = self.transformation_dictionary[ + (self.target[0].__class__, self.domain[0].__class__)] + + self._forward_transformation = TransformationCache.create( + forward_class, self.domain[0], self.target[0], module=module) + + self._backward_transformation = TransformationCache.create( + backward_class, self.target[0], self.domain[0], module=module) + + def _prep(self, x, spaces, dom): + assert issubclass(x.dtype.type,np.floating), \ + "Argument must be real-valued" + spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) + if spaces is None: + # this case means that x lives on only one space, which is + # identical to the space in the domain of `self`. Otherwise the + # input check of LinearOperator would have failed. + axes = x.domain_axes[0] + else: + axes = x.domain_axes[spaces[0]] + + if spaces is None: + result_domain = dom + else: + result_domain = list(x.domain) + result_domain[spaces[0]] = dom[0] + + result_field = x.copy_empty(domain=result_domain, dtype=x.dtype) + return spaces, axes, result_field + + def _times(self, x, spaces): + spaces, axes, result_field = self._prep(x, spaces, self.target) + + if type(self._domain[0]) != RGSpace: + new_val = self._forward_transformation.transform(x.val, axes=axes) + result_field.set_val(new_val=new_val, copy=True) + return result_field + + if self._target[0].harmonic: # going to harmonic space + new_val = self._forward_transformation.transform(x.val, axes=axes) + result_field.set_val(new_val=new_val.real+new_val.imag) + else: + tval = self._domain[0].hermitianize_inverter(x.val, axes) + tval = 0.5*((x.val+tval)+1j*(x.val-tval)) + new_val = self._forward_transformation.transform(tval, axes=axes) + result_field.set_val(new_val=new_val.real) + return result_field + + def _adjoint_times(self, x, spaces): + spaces, axes, result_field = self._prep(x, spaces, self.domain) + + if type(self._domain[0]) != RGSpace: + new_val = self._backward_transformation.transform(x.val, axes=axes) + result_field.set_val(new_val=new_val, copy=True) + return result_field + + if self._domain[0].harmonic: # going to harmonic space + new_val = self._backward_transformation.transform(x.val, axes=axes) + result_field.set_val(new_val=new_val.real+new_val.imag) + else: + tval = self._target[0].hermitianize_inverter(x.val, axes) + tval = 0.5*((x.val+tval)+1j*(x.val-tval)) + new_val = self._backward_transformation.transform(tval, axes=axes) + result_field.set_val(new_val=new_val.real) + return result_field + + # ---Mandatory properties and methods--- + + @property + def domain(self): + return self._domain + + @property + def target(self): + return self._target + + @property + def unitary(self): + return (self._forward_transformation.unitary and + self._backward_transformation.unitary) + + # ---Added properties and methods--- + + @classmethod + def get_default_codomain(cls, domain): + """Returns a codomain to the given domain. + + Parameters + ---------- + domain: Space + An instance of RGSpace, HPSpace, GLSpace or LMSpace. + + Returns + ------- + target: Space + A (more or less perfect) counterpart to "domain" with respect + to a FFT operation. + Whenever "domain" is an RGSpace, the codomain (and its parameters) + are uniquely determined. + For GLSpace, HPSpace, and LMSpace, a sensible (but not unique) + co-domain is chosen that should work satisfactorily in most + situations. For full control however, the user should not rely on + this method. + + Raises + ------ + ValueError: + if no default codomain is defined for "domain". + + """ + domain_class = domain.__class__ + try: + codomain_class = cls.default_codomain_dictionary[domain_class] + except KeyError: + raise ValueError("Unknown domain") + + try: + transform_class = cls.transformation_dictionary[(domain_class, + codomain_class)] + except KeyError: + raise ValueError( + "No transformation for domain-codomain pair found.") + + return transform_class.get_codomain(domain) diff --git a/nifty/sugar.py b/nifty/sugar.py index 7c1941ef13f26b53cccc23e99059a29caa46fd44..500ae4460a1f823cb6a661d6c959e52e2fd93a58 100644 --- a/nifty/sugar.py +++ b/nifty/sugar.py @@ -71,6 +71,9 @@ def create_power_operator(domain, power_spectrum, dtype=None, distribution_strategy='not') f = fp.power_synthesize(mean=1, std=0, real_signal=False, distribution_strategy=distribution_strategy) + # MR FIXME: we need the real part here. Could this also be achieved + # by setting real_signal=True in the call above? + f = f.real f **= 2 return DiagonalOperator(domain, diagonal=f, bare=True) diff --git a/test/test_operators/test_real_fft_operator.py b/test/test_operators/test_real_fft_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0c63db213b4205a9fc58cf31ec224e49700f16 --- /dev/null +++ b/test/test_operators/test_real_fft_operator.py @@ -0,0 +1,137 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Copyright(C) 2013-2017 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +import unittest +import numpy as np +from numpy.testing import assert_equal,\ + assert_allclose +from nifty.config import dependency_injector as gdi +from nifty import Field,\ + RGSpace,\ + LMSpace,\ + HPSpace,\ + GLSpace,\ + RealFFTOperator +from itertools import product +from test.common import expand +from nose.plugins.skip import SkipTest + + +def _get_rtol(tp): + if (tp == np.float64) or (tp == np.complex128): + return 1e-10 + else: + return 1e-5 + + +class RealFFTOperatorTests(unittest.TestCase): + @expand(product(["numpy", "fftw", "fftw_mpi"], + [16, ], [0.1, 1, 3.7], + [np.float64, np.float32])) + def test_fft1D(self, module, dim1, d, itp): + if module == "fftw_mpi": + if not hasattr(gdi.get('fftw'), 'FFTW_MPI'): + raise SkipTest + if module == "fftw" and "fftw" not in gdi: + raise SkipTest + tol = _get_rtol(itp) + a = RGSpace(dim1, distances=d) + b = RGSpace(dim1, distances=1./(dim1*d), harmonic=True) + fft = RealFFTOperator(domain=a, target=b, module=module) + np.random.seed(16) + inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3, + dtype=itp) + out = fft.adjoint_times(fft.times(inp)) + assert_allclose(inp.val.get_full_data(), + out.val.get_full_data(), + rtol=tol, atol=tol) + + @expand(product(["numpy", "fftw", "fftw_mpi"], + [12, 15], [9, 12], [0.1, 1, 3.7], + [0.4, 1, 2.7], + [np.float64, np.float32])) + def test_fft2D(self, module, dim1, dim2, d1, d2, itp): + if module == "fftw_mpi": + if not hasattr(gdi.get('fftw'), 'FFTW_MPI'): + raise SkipTest + if module == "fftw" and "fftw" not in gdi: + raise SkipTest + tol = _get_rtol(itp) + a = RGSpace([dim1, dim2], distances=[d1, d2]) + b = RGSpace([dim1, dim2], + distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True) + fft = RealFFTOperator(domain=a, target=b, module=module) + inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3, + dtype=itp) + out = fft.adjoint_times(fft.times(inp)) + assert_allclose(inp.val, out.val, rtol=tol, atol=tol) + + @expand(product([0, 3, 6, 11, 30], [np.float64, np.float32])) + def test_sht(self, lm, tp): + if 'pyHealpix' not in gdi: + raise SkipTest + tol = _get_rtol(tp) + a = LMSpace(lmax=lm) + b = GLSpace(nlat=lm+1) + fft = RealFFTOperator(domain=a, target=b) + inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3, + dtype=tp) + out = fft.adjoint_times(fft.times(inp)) + assert_allclose(inp.val, out.val, rtol=tol, atol=tol) + + @expand(product([128, 256], [np.float64, np.float32])) + def test_sht2(self, lm, tp): + if 'pyHealpix' not in gdi: + raise SkipTest + a = LMSpace(lmax=lm) + b = HPSpace(nside=lm//2) + fft = RealFFTOperator(domain=a, target=b) + inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0, + dtype=tp) + out = fft.adjoint_times(fft.times(inp)) + assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1) + + @expand(product([128, 256], [np.float64, np.float32])) + def test_dotsht(self, lm, tp): + if 'pyHealpix' not in gdi: + raise SkipTest + tol = _get_rtol(tp) + a = LMSpace(lmax=lm) + b = GLSpace(nlat=lm+1) + fft = RealFFTOperator(domain=a, target=b) + inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0, + dtype=tp) + out = fft.times(inp) + v1 = np.sqrt(out.vdot(out)) + v2 = np.sqrt(inp.vdot(fft.adjoint_times(out))) + assert_allclose(v1, v2, rtol=tol, atol=tol) + + @expand(product([128, 256], [np.float64, np.float32])) + def test_dotsht2(self, lm, tp): + if 'pyHealpix' not in gdi: + raise SkipTest + tol = _get_rtol(tp) + a = LMSpace(lmax=lm) + b = HPSpace(nside=lm//2) + fft = RealFFTOperator(domain=a, target=b) + inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0, + dtype=tp) + out = fft.times(inp) + v1 = np.sqrt(out.vdot(out)) + v2 = np.sqrt(inp.vdot(fft.adjoint_times(out))) + assert_allclose(v1, v2, rtol=tol, atol=tol)