diff --git a/demos/wiener_filter_via_curvature.py b/demos/wiener_filter_via_curvature.py
index 0224333321b1bf1621bc0c10ae3a759d9d0fb0f6..d3b826cd7c026f019f6101755db3cc3d9b77f781 100644
--- a/demos/wiener_filter_via_curvature.py
+++ b/demos/wiener_filter_via_curvature.py
@@ -13,11 +13,11 @@ if __name__ == "__main__":
# Setting up variable parameters
# Typical distance over which the field is correlated
- correlation_length = 0.01
+ correlation_length = 0.05
# Variance of field in position space sqrt(<|s_x|^2>)
field_variance = 2.
# smoothing length of response (in same unit as L)
- response_sigma = 0.1
+ response_sigma = 0.01
# The signal to noise ratio
signal_to_noise = 0.7
@@ -73,9 +73,10 @@ if __name__ == "__main__":
m_s = fft(m)
plotter = plotting.RG2DPlotter()
- plotter.title = 'mock_signal.html';
+ plotter.path = 'mock_signal.html'
plotter(mock_signal)
- plotter.title = 'data.html'
+ plotter.path = 'data.html'
plotter(Field(signal_space,
val=data.val.get_full_data().reshape(signal_space.shape)))
- plotter.title = 'map.html'; plotter(m_s)
\ No newline at end of file
+ plotter.path = 'map.html'
+ plotter(m_s)
diff --git a/demos/wiener_filter_via_curvature_real.py b/demos/wiener_filter_via_curvature_real.py
new file mode 100644
index 0000000000000000000000000000000000000000..4daed2cf8c8da41a5634b103fc74f18e14e79a20
--- /dev/null
+++ b/demos/wiener_filter_via_curvature_real.py
@@ -0,0 +1,82 @@
+import numpy as np
+
+from nifty import RGSpace, PowerSpace, Field, RealFFTOperator,\
+ ComposedOperator, DiagonalOperator, ResponseOperator,\
+ plotting, create_power_operator
+from nifty.library import WienerFilterCurvature
+
+
+if __name__ == "__main__":
+
+ distribution_strategy = 'not'
+
+ # Setting up variable parameters
+
+ # Typical distance over which the field is correlated
+ correlation_length = 0.05
+ # Variance of field in position space sqrt(<|s_x|^2>)
+ field_variance = 2.
+ # smoothing length of response (in same unit as L)
+ response_sigma = 0.01
+ # The signal to noise ratio
+ signal_to_noise = 0.7
+
+ # note that field_variance**2 = a*k_0/4. for this analytic form of power
+ # spectrum
+ def power_spectrum(k):
+ a = 4 * correlation_length * field_variance**2
+ return a / (1 + k * correlation_length) ** 4
+
+ # Setting up the geometry
+
+ # Total side-length of the domain
+ L = 2.
+ # Grid resolution (pixels per axis)
+ N_pixels = 512
+
+ signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
+ harmonic_space = RealFFTOperator.get_default_codomain(signal_space)
+ fft = RealFFTOperator(harmonic_space, target=signal_space)
+ power_space = PowerSpace(harmonic_space,
+ distribution_strategy=distribution_strategy)
+
+ # Creating the mock data
+ S = create_power_operator(harmonic_space, power_spectrum=power_spectrum,
+ distribution_strategy=distribution_strategy)
+
+ mock_power = Field(power_space, val=power_spectrum,
+ distribution_strategy=distribution_strategy)
+ np.random.seed(43)
+ mock_harmonic = mock_power.power_synthesize(real_signal=True)
+ mock_harmonic = mock_harmonic.real + mock_harmonic.imag
+ mock_signal = fft(mock_harmonic)
+
+ R = ResponseOperator(signal_space, sigma=(response_sigma,))
+ data_domain = R.target[0]
+ R_harmonic = ComposedOperator([fft, R], default_spaces=[0, 0])
+
+ N = DiagonalOperator(data_domain,
+ diagonal=mock_signal.var()/signal_to_noise,
+ bare=True)
+ noise = Field.from_random(domain=data_domain,
+ random_type='normal',
+ std=mock_signal.std()/np.sqrt(signal_to_noise),
+ mean=0)
+ data = R(mock_signal) + noise
+
+ # Wiener filter
+
+ j = R_harmonic.adjoint_times(N.inverse_times(data))
+ wiener_curvature = WienerFilterCurvature(S=S, N=N, R=R_harmonic)
+
+ m = wiener_curvature.inverse_times(j)
+ m_s = fft(m)
+
+ plotter = plotting.RG2DPlotter()
+ plotter.path = 'mock_signal.html'
+ plotter(mock_signal)
+ plotter.path = 'data.html'
+ plotter(Field(signal_space,
+ val=data.val.get_full_data().reshape(signal_space.shape)))
+ plotter.path = 'map.html'
+ plotter(m_s)
diff --git a/nifty/operators/fft_operator/__init__.py b/nifty/operators/fft_operator/__init__.py
index c0247fdece024338baf35fc145f7fb27b0944f7a..53b620e85eddbef6643bc12b5de7ed19fe52cdcf 100644
--- a/nifty/operators/fft_operator/__init__.py
+++ b/nifty/operators/fft_operator/__init__.py
@@ -17,4 +17,5 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from transformations import *
-from fft_operator import FFTOperator
+from .fft_operator import FFTOperator
+from .real_fft_operator import RealFFTOperator
diff --git a/nifty/operators/fft_operator/real_fft_operator.py b/nifty/operators/fft_operator/real_fft_operator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3413c1b875c4d8615393d68d8112d471aeb1fa6
--- /dev/null
+++ b/nifty/operators/fft_operator/real_fft_operator.py
@@ -0,0 +1,254 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+#
+# Copyright(C) 2013-2017 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+import numpy as np
+
+import nifty.nifty_utilities as utilities
+from nifty.spaces import RGSpace,\
+ GLSpace,\
+ HPSpace,\
+ LMSpace
+
+from nifty.operators.linear_operator import LinearOperator
+from transformations import RGRGTransformation,\
+ LMGLTransformation,\
+ LMHPTransformation,\
+ GLLMTransformation,\
+ HPLMTransformation,\
+ TransformationCache
+
+
+class RealFFTOperator(LinearOperator):
+ """Transforms between a pair of position and harmonic domains.
+
+ Built-in domain pairs are
+ - a harmonic and a non-harmonic RGSpace (with matching distances)
+ - a HPSpace and a LMSpace
+ - a GLSpace and a LMSpace
+ Within a domain pair, both orderings are possible.
+
+ The operator provides a "times" and an "adjoint_times" operation.
+ For a pair of RGSpaces, the "adjoint_times" operation is equivalent to
+ "inverse_times"; for the sphere-related domains this is not the case, since
+ the operator matrix is not square.
+
+ In contrast to the FFTOperator, RealFFTOperator accepts and returns
+ real-valued fields only. For the harmonic-space counterpart of a
+ real-valued field living on an RGSpace, the sum of real
+ and imaginary components is stored. Since the full complex field has
+ Hermitian symmetry, this is sufficient to reconstruct the full field
+ whenever needed (e.g. during the transform back to position space).
+
+ Parameters
+ ----------
+ domain: Space or single-element tuple of Spaces
+ The domain of the data that is input by "times" and output by
+ "adjoint_times".
+ target: Space or single-element tuple of Spaces (optional)
+ The domain of the data that is output by "times" and input by
+ "adjoint_times".
+ If omitted, a co-domain will be chosen automatically.
+ Whenever "domain" is an RGSpace, the codomain (and its parameters) are
+ uniquely determined.
+ For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
+ co-domain is chosen that should work satisfactorily in most situations,
+ but for full control, the user should explicitly specify a codomain.
+ module: String (optional)
+ Software module employed for carrying out the transform operations.
+ For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
+ always available (using pyfftw if available, else numpy.fft), and "mpi"
+ requires pyfftw and offers MPI parallelization.
+ For sphere-related domains, only "pyHealpix" is
+ available. If omitted, "fftw" is selected for RGSpaces if available,
+ else "numpy"; on the sphere the default is "pyHealpix".
+
+ Attributes
+ ----------
+ domain: Tuple of Spaces (with one entry)
+ The domain of the data that is input by "times" and output by
+ "adjoint_times".
+ target: Tuple of Spaces (with one entry)
+ The domain of the data that is output by "times" and input by
+ "adjoint_times".
+ unitary: bool
+ Returns True if the operator is unitary (currently only the case if
+ the domain and codomain are RGSpaces), else False.
+
+ Raises
+ ------
+ ValueError:
+ if "domain" or "target" are not of the proper type.
+
+ """
+ default_codomain_dictionary = {RGSpace: RGSpace,
+ HPSpace: LMSpace,
+ GLSpace: LMSpace,
+ LMSpace: GLSpace,
+ }
+
+ transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
+ (HPSpace, LMSpace): HPLMTransformation,
+ (GLSpace, LMSpace): GLLMTransformation,
+ (LMSpace, HPSpace): LMHPTransformation,
+ (LMSpace, GLSpace): LMGLTransformation
+ }
+
+ def __init__(self, domain, target=None, module=None,
+ default_spaces=None):
+ super(RealFFTOperator, self).__init__(default_spaces)
+
+ # Initialize domain and target
+ self._domain = self._parse_domain(domain)
+ if len(self.domain) != 1:
+ raise ValueError("TransformationOperator accepts only exactly one "
+ "space as input domain.")
+
+ if target is None:
+ target = (self.get_default_codomain(self.domain[0]), )
+ self._target = self._parse_domain(target)
+ if len(self.target) != 1:
+ raise ValueError("TransformationOperator accepts only exactly one "
+ "space as output target.")
+
+ # Create transformation instances
+ forward_class = self.transformation_dictionary[
+ (self.domain[0].__class__, self.target[0].__class__)]
+ backward_class = self.transformation_dictionary[
+ (self.target[0].__class__, self.domain[0].__class__)]
+
+ self._forward_transformation = TransformationCache.create(
+ forward_class, self.domain[0], self.target[0], module=module)
+
+ self._backward_transformation = TransformationCache.create(
+ backward_class, self.target[0], self.domain[0], module=module)
+
+ def _prep(self, x, spaces, dom):
+ assert issubclass(x.dtype.type,np.floating), \
+ "Argument must be real-valued"
+ spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
+ if spaces is None:
+ # this case means that x lives on only one space, which is
+ # identical to the space in the domain of `self`. Otherwise the
+ # input check of LinearOperator would have failed.
+ axes = x.domain_axes[0]
+ else:
+ axes = x.domain_axes[spaces[0]]
+
+ if spaces is None:
+ result_domain = dom
+ else:
+ result_domain = list(x.domain)
+ result_domain[spaces[0]] = dom[0]
+
+ result_field = x.copy_empty(domain=result_domain, dtype=x.dtype)
+ return spaces, axes, result_field
+
+ def _times(self, x, spaces):
+ spaces, axes, result_field = self._prep(x, spaces, self.target)
+
+ if type(self._domain[0]) != RGSpace:
+ new_val = self._forward_transformation.transform(x.val, axes=axes)
+ result_field.set_val(new_val=new_val, copy=True)
+ return result_field
+
+ if self._target[0].harmonic: # going to harmonic space
+ new_val = self._forward_transformation.transform(x.val, axes=axes)
+ result_field.set_val(new_val=new_val.real+new_val.imag)
+ else:
+ tval = self._domain[0].hermitianize_inverter(x.val, axes)
+ tval = 0.5*((x.val+tval)+1j*(x.val-tval))
+ new_val = self._forward_transformation.transform(tval, axes=axes)
+ result_field.set_val(new_val=new_val.real)
+ return result_field
+
+ def _adjoint_times(self, x, spaces):
+ spaces, axes, result_field = self._prep(x, spaces, self.domain)
+
+ if type(self._domain[0]) != RGSpace:
+ new_val = self._backward_transformation.transform(x.val, axes=axes)
+ result_field.set_val(new_val=new_val, copy=True)
+ return result_field
+
+ if self._domain[0].harmonic: # going to harmonic space
+ new_val = self._backward_transformation.transform(x.val, axes=axes)
+ result_field.set_val(new_val=new_val.real+new_val.imag)
+ else:
+ tval = self._target[0].hermitianize_inverter(x.val, axes)
+ tval = 0.5*((x.val+tval)+1j*(x.val-tval))
+ new_val = self._backward_transformation.transform(tval, axes=axes)
+ result_field.set_val(new_val=new_val.real)
+ return result_field
+
+ # ---Mandatory properties and methods---
+
+ @property
+ def domain(self):
+ return self._domain
+
+ @property
+ def target(self):
+ return self._target
+
+ @property
+ def unitary(self):
+ return (self._forward_transformation.unitary and
+ self._backward_transformation.unitary)
+
+ # ---Added properties and methods---
+
+ @classmethod
+ def get_default_codomain(cls, domain):
+ """Returns a codomain to the given domain.
+
+ Parameters
+ ----------
+ domain: Space
+ An instance of RGSpace, HPSpace, GLSpace or LMSpace.
+
+ Returns
+ -------
+ target: Space
+ A (more or less perfect) counterpart to "domain" with respect
+ to a FFT operation.
+ Whenever "domain" is an RGSpace, the codomain (and its parameters)
+ are uniquely determined.
+ For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
+ co-domain is chosen that should work satisfactorily in most
+ situations. For full control however, the user should not rely on
+ this method.
+
+ Raises
+ ------
+ ValueError:
+ if no default codomain is defined for "domain".
+
+ """
+ domain_class = domain.__class__
+ try:
+ codomain_class = cls.default_codomain_dictionary[domain_class]
+ except KeyError:
+ raise ValueError("Unknown domain")
+
+ try:
+ transform_class = cls.transformation_dictionary[(domain_class,
+ codomain_class)]
+ except KeyError:
+ raise ValueError(
+ "No transformation for domain-codomain pair found.")
+
+ return transform_class.get_codomain(domain)
diff --git a/nifty/sugar.py b/nifty/sugar.py
index 7c1941ef13f26b53cccc23e99059a29caa46fd44..500ae4460a1f823cb6a661d6c959e52e2fd93a58 100644
--- a/nifty/sugar.py
+++ b/nifty/sugar.py
@@ -71,6 +71,9 @@ def create_power_operator(domain, power_spectrum, dtype=None,
distribution_strategy='not')
f = fp.power_synthesize(mean=1, std=0, real_signal=False,
distribution_strategy=distribution_strategy)
+ # MR FIXME: we need the real part here. Could this also be achieved
+ # by setting real_signal=True in the call above?
+ f = f.real
f **= 2
return DiagonalOperator(domain, diagonal=f, bare=True)
diff --git a/test/test_operators/test_real_fft_operator.py b/test/test_operators/test_real_fft_operator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0c63db213b4205a9fc58cf31ec224e49700f16
--- /dev/null
+++ b/test/test_operators/test_real_fft_operator.py
@@ -0,0 +1,137 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+#
+# Copyright(C) 2013-2017 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+import unittest
+import numpy as np
+from numpy.testing import assert_equal,\
+ assert_allclose
+from nifty.config import dependency_injector as gdi
+from nifty import Field,\
+ RGSpace,\
+ LMSpace,\
+ HPSpace,\
+ GLSpace,\
+ RealFFTOperator
+from itertools import product
+from test.common import expand
+from nose.plugins.skip import SkipTest
+
+
+def _get_rtol(tp):
+ if (tp == np.float64) or (tp == np.complex128):
+ return 1e-10
+ else:
+ return 1e-5
+
+
+class RealFFTOperatorTests(unittest.TestCase):
+ @expand(product(["numpy", "fftw", "fftw_mpi"],
+ [16, ], [0.1, 1, 3.7],
+ [np.float64, np.float32]))
+ def test_fft1D(self, module, dim1, d, itp):
+ if module == "fftw_mpi":
+ if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
+ raise SkipTest
+ if module == "fftw" and "fftw" not in gdi:
+ raise SkipTest
+ tol = _get_rtol(itp)
+ a = RGSpace(dim1, distances=d)
+ b = RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
+ fft = RealFFTOperator(domain=a, target=b, module=module)
+ np.random.seed(16)
+ inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
+ dtype=itp)
+ out = fft.adjoint_times(fft.times(inp))
+ assert_allclose(inp.val.get_full_data(),
+ out.val.get_full_data(),
+ rtol=tol, atol=tol)
+
+ @expand(product(["numpy", "fftw", "fftw_mpi"],
+ [12, 15], [9, 12], [0.1, 1, 3.7],
+ [0.4, 1, 2.7],
+ [np.float64, np.float32]))
+ def test_fft2D(self, module, dim1, dim2, d1, d2, itp):
+ if module == "fftw_mpi":
+ if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
+ raise SkipTest
+ if module == "fftw" and "fftw" not in gdi:
+ raise SkipTest
+ tol = _get_rtol(itp)
+ a = RGSpace([dim1, dim2], distances=[d1, d2])
+ b = RGSpace([dim1, dim2],
+ distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
+ fft = RealFFTOperator(domain=a, target=b, module=module)
+ inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
+ dtype=itp)
+ out = fft.adjoint_times(fft.times(inp))
+ assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
+
+ @expand(product([0, 3, 6, 11, 30], [np.float64, np.float32]))
+ def test_sht(self, lm, tp):
+ if 'pyHealpix' not in gdi:
+ raise SkipTest
+ tol = _get_rtol(tp)
+ a = LMSpace(lmax=lm)
+ b = GLSpace(nlat=lm+1)
+ fft = RealFFTOperator(domain=a, target=b)
+ inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
+ dtype=tp)
+ out = fft.adjoint_times(fft.times(inp))
+ assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
+
+ @expand(product([128, 256], [np.float64, np.float32]))
+ def test_sht2(self, lm, tp):
+ if 'pyHealpix' not in gdi:
+ raise SkipTest
+ a = LMSpace(lmax=lm)
+ b = HPSpace(nside=lm//2)
+ fft = RealFFTOperator(domain=a, target=b)
+ inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
+ dtype=tp)
+ out = fft.adjoint_times(fft.times(inp))
+ assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)
+
+ @expand(product([128, 256], [np.float64, np.float32]))
+ def test_dotsht(self, lm, tp):
+ if 'pyHealpix' not in gdi:
+ raise SkipTest
+ tol = _get_rtol(tp)
+ a = LMSpace(lmax=lm)
+ b = GLSpace(nlat=lm+1)
+ fft = RealFFTOperator(domain=a, target=b)
+ inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
+ dtype=tp)
+ out = fft.times(inp)
+ v1 = np.sqrt(out.vdot(out))
+ v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
+ assert_allclose(v1, v2, rtol=tol, atol=tol)
+
+ @expand(product([128, 256], [np.float64, np.float32]))
+ def test_dotsht2(self, lm, tp):
+ if 'pyHealpix' not in gdi:
+ raise SkipTest
+ tol = _get_rtol(tp)
+ a = LMSpace(lmax=lm)
+ b = HPSpace(nside=lm//2)
+ fft = RealFFTOperator(domain=a, target=b)
+ inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
+ dtype=tp)
+ out = fft.times(inp)
+ v1 = np.sqrt(out.vdot(out))
+ v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
+ assert_allclose(v1, v2, rtol=tol, atol=tol)