Commit b3fdc443 authored by Martin Reinecke's avatar Martin Reinecke

introducing the RealFFTOperator

parent 89ff8afe
Pipeline #16510 passed with stage
in 9 minutes and 30 seconds
......@@ -13,11 +13,11 @@ if __name__ == "__main__":
# Setting up variable parameters
# Typical distance over which the field is correlated
correlation_length = 0.01
correlation_length = 0.05
# Variance of field in position space sqrt(<|s_x|^2>)
field_variance = 2.
# smoothing length of response (in same unit as L)
response_sigma = 0.1
response_sigma = 0.01
# The signal to noise ratio
signal_to_noise = 0.7
......@@ -73,9 +73,10 @@ if __name__ == "__main__":
m_s = fft(m)
plotter = plotting.RG2DPlotter()
plotter.title = 'mock_signal.html';
plotter.path = 'mock_signal.html'
plotter(mock_signal)
plotter.title = 'data.html'
plotter.path = 'data.html'
plotter(Field(signal_space,
val=data.val.get_full_data().reshape(signal_space.shape)))
plotter.title = 'map.html'; plotter(m_s)
\ No newline at end of file
plotter.path = 'map.html'
plotter(m_s)
import numpy as np
from nifty import RGSpace, PowerSpace, Field, RealFFTOperator,\
ComposedOperator, DiagonalOperator, ResponseOperator,\
plotting, create_power_operator
from nifty.library import WienerFilterCurvature
if __name__ == "__main__":
distribution_strategy = 'not'
# Setting up variable parameters
# Typical distance over which the field is correlated
correlation_length = 0.05
# Variance of field in position space sqrt(<|s_x|^2>)
field_variance = 2.
# smoothing length of response (in same unit as L)
response_sigma = 0.01
# The signal to noise ratio
signal_to_noise = 0.7
# note that field_variance**2 = a*k_0/4. for this analytic form of power
# spectrum
def power_spectrum(k):
a = 4 * correlation_length * field_variance**2
return a / (1 + k * correlation_length) ** 4
# Setting up the geometry
# Total side-length of the domain
L = 2.
# Grid resolution (pixels per axis)
N_pixels = 512
signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
harmonic_space = RealFFTOperator.get_default_codomain(signal_space)
fft = RealFFTOperator(harmonic_space, target=signal_space)
power_space = PowerSpace(harmonic_space,
distribution_strategy=distribution_strategy)
# Creating the mock data
S = create_power_operator(harmonic_space, power_spectrum=power_spectrum,
distribution_strategy=distribution_strategy)
mock_power = Field(power_space, val=power_spectrum,
distribution_strategy=distribution_strategy)
np.random.seed(43)
mock_harmonic = mock_power.power_synthesize(real_signal=True)
mock_harmonic = mock_harmonic.real + mock_harmonic.imag
mock_signal = fft(mock_harmonic)
R = ResponseOperator(signal_space, sigma=(response_sigma,))
data_domain = R.target[0]
R_harmonic = ComposedOperator([fft, R], default_spaces=[0, 0])
N = DiagonalOperator(data_domain,
diagonal=mock_signal.var()/signal_to_noise,
bare=True)
noise = Field.from_random(domain=data_domain,
random_type='normal',
std=mock_signal.std()/np.sqrt(signal_to_noise),
mean=0)
data = R(mock_signal) + noise
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
wiener_curvature = WienerFilterCurvature(S=S, N=N, R=R_harmonic)
m = wiener_curvature.inverse_times(j)
m_s = fft(m)
plotter = plotting.RG2DPlotter()
plotter.path = 'mock_signal.html'
plotter(mock_signal)
plotter.path = 'data.html'
plotter(Field(signal_space,
val=data.val.get_full_data().reshape(signal_space.shape)))
plotter.path = 'map.html'
plotter(m_s)
......@@ -17,4 +17,5 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from transformations import *
from fft_operator import FFTOperator
from .fft_operator import FFTOperator
from .real_fft_operator import RealFFTOperator
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
import nifty.nifty_utilities as utilities
from nifty.spaces import RGSpace,\
GLSpace,\
HPSpace,\
LMSpace
from nifty.operators.linear_operator import LinearOperator
from transformations import RGRGTransformation,\
LMGLTransformation,\
LMHPTransformation,\
GLLMTransformation,\
HPLMTransformation,\
TransformationCache
class RealFFTOperator(LinearOperator):
"""Transforms between a pair of position and harmonic domains.
Built-in domain pairs are
- a harmonic and a non-harmonic RGSpace (with matching distances)
- a HPSpace and a LMSpace
- a GLSpace and a LMSpace
Within a domain pair, both orderings are possible.
The operator provides a "times" and an "adjoint_times" operation.
For a pair of RGSpaces, the "adjoint_times" operation is equivalent to
"inverse_times"; for the sphere-related domains this is not the case, since
the operator matrix is not square.
In contrast to the FFTOperator, RealFFTOperator accepts and returns
real-valued fields only. For the harmonic-space counterpart of a
real-valued field living on an RGSpace, the sum of real
and imaginary components is stored. Since the full complex field has
Hermitian symmetry, this is sufficient to reconstruct the full field
whenever needed (e.g. during the transform back to position space).
Parameters
----------
domain: Space or single-element tuple of Spaces
The domain of the data that is input by "times" and output by
"adjoint_times".
target: Space or single-element tuple of Spaces (optional)
The domain of the data that is output by "times" and input by
"adjoint_times".
If omitted, a co-domain will be chosen automatically.
Whenever "domain" is an RGSpace, the codomain (and its parameters) are
uniquely determined.
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most situations,
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
always available (using pyfftw if available, else numpy.fft), and "mpi"
requires pyfftw and offers MPI parallelization.
For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
Attributes
----------
domain: Tuple of Spaces (with one entry)
The domain of the data that is input by "times" and output by
"adjoint_times".
target: Tuple of Spaces (with one entry)
The domain of the data that is output by "times" and input by
"adjoint_times".
unitary: bool
Returns True if the operator is unitary (currently only the case if
the domain and codomain are RGSpaces), else False.
Raises
------
ValueError:
if "domain" or "target" are not of the proper type.
"""
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
GLSpace: LMSpace,
LMSpace: GLSpace,
}
transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
(HPSpace, LMSpace): HPLMTransformation,
(GLSpace, LMSpace): GLLMTransformation,
(LMSpace, HPSpace): LMHPTransformation,
(LMSpace, GLSpace): LMGLTransformation
}
def __init__(self, domain, target=None, module=None,
default_spaces=None):
super(RealFFTOperator, self).__init__(default_spaces)
# Initialize domain and target
self._domain = self._parse_domain(domain)
if len(self.domain) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as input domain.")
if target is None:
target = (self.get_default_codomain(self.domain[0]), )
self._target = self._parse_domain(target)
if len(self.target) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as output target.")
# Create transformation instances
forward_class = self.transformation_dictionary[
(self.domain[0].__class__, self.target[0].__class__)]
backward_class = self.transformation_dictionary[
(self.target[0].__class__, self.domain[0].__class__)]
self._forward_transformation = TransformationCache.create(
forward_class, self.domain[0], self.target[0], module=module)
self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module)
def _prep(self, x, spaces, dom):
assert issubclass(x.dtype.type,np.floating), \
"Argument must be real-valued"
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
else:
axes = x.domain_axes[spaces[0]]
if spaces is None:
result_domain = dom
else:
result_domain = list(x.domain)
result_domain[spaces[0]] = dom[0]
result_field = x.copy_empty(domain=result_domain, dtype=x.dtype)
return spaces, axes, result_field
def _times(self, x, spaces):
spaces, axes, result_field = self._prep(x, spaces, self.target)
if type(self._domain[0]) != RGSpace:
new_val = self._forward_transformation.transform(x.val, axes=axes)
result_field.set_val(new_val=new_val, copy=True)
return result_field
if self._target[0].harmonic: # going to harmonic space
new_val = self._forward_transformation.transform(x.val, axes=axes)
result_field.set_val(new_val=new_val.real+new_val.imag)
else:
tval = self._domain[0].hermitianize_inverter(x.val, axes)
tval = 0.5*((x.val+tval)+1j*(x.val-tval))
new_val = self._forward_transformation.transform(tval, axes=axes)
result_field.set_val(new_val=new_val.real)
return result_field
def _adjoint_times(self, x, spaces):
spaces, axes, result_field = self._prep(x, spaces, self.domain)
if type(self._domain[0]) != RGSpace:
new_val = self._backward_transformation.transform(x.val, axes=axes)
result_field.set_val(new_val=new_val, copy=True)
return result_field
if self._domain[0].harmonic: # going to harmonic space
new_val = self._backward_transformation.transform(x.val, axes=axes)
result_field.set_val(new_val=new_val.real+new_val.imag)
else:
tval = self._target[0].hermitianize_inverter(x.val, axes)
tval = 0.5*((x.val+tval)+1j*(x.val-tval))
new_val = self._backward_transformation.transform(tval, axes=axes)
result_field.set_val(new_val=new_val.real)
return result_field
# ---Mandatory properties and methods---
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def unitary(self):
return (self._forward_transformation.unitary and
self._backward_transformation.unitary)
# ---Added properties and methods---
@classmethod
def get_default_codomain(cls, domain):
"""Returns a codomain to the given domain.
Parameters
----------
domain: Space
An instance of RGSpace, HPSpace, GLSpace or LMSpace.
Returns
-------
target: Space
A (more or less perfect) counterpart to "domain" with respect
to a FFT operation.
Whenever "domain" is an RGSpace, the codomain (and its parameters)
are uniquely determined.
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most
situations. For full control however, the user should not rely on
this method.
Raises
------
ValueError:
if no default codomain is defined for "domain".
"""
domain_class = domain.__class__
try:
codomain_class = cls.default_codomain_dictionary[domain_class]
except KeyError:
raise ValueError("Unknown domain")
try:
transform_class = cls.transformation_dictionary[(domain_class,
codomain_class)]
except KeyError:
raise ValueError(
"No transformation for domain-codomain pair found.")
return transform_class.get_codomain(domain)
......@@ -71,6 +71,9 @@ def create_power_operator(domain, power_spectrum, dtype=None,
distribution_strategy='not')
f = fp.power_synthesize(mean=1, std=0, real_signal=False,
distribution_strategy=distribution_strategy)
# MR FIXME: we need the real part here. Could this also be achieved
# by setting real_signal=True in the call above?
f = f.real
f **= 2
return DiagonalOperator(domain, diagonal=f, bare=True)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
import numpy as np
from numpy.testing import assert_equal,\
assert_allclose
from nifty.config import dependency_injector as gdi
from nifty import Field,\
RGSpace,\
LMSpace,\
HPSpace,\
GLSpace,\
RealFFTOperator
from itertools import product
from test.common import expand
from nose.plugins.skip import SkipTest
def _get_rtol(tp):
if (tp == np.float64) or (tp == np.complex128):
return 1e-10
else:
return 1e-5
class RealFFTOperatorTests(unittest.TestCase):
@expand(product(["numpy", "fftw", "fftw_mpi"],
[16, ], [0.1, 1, 3.7],
[np.float64, np.float32]))
def test_fft1D(self, module, dim1, d, itp):
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace(dim1, distances=d)
b = RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
fft = RealFFTOperator(domain=a, target=b, module=module)
np.random.seed(16)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=itp)
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val.get_full_data(),
out.val.get_full_data(),
rtol=tol, atol=tol)
@expand(product(["numpy", "fftw", "fftw_mpi"],
[12, 15], [9, 12], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.float32]))
def test_fft2D(self, module, dim1, dim2, d1, d2, itp):
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], distances=[d1, d2])
b = RGSpace([dim1, dim2],
distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
fft = RealFFTOperator(domain=a, target=b, module=module)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=itp)
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product([0, 3, 6, 11, 30], [np.float64, np.float32]))
def test_sht(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1)
fft = RealFFTOperator(domain=a, target=b)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=tp)
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product([128, 256], [np.float64, np.float32]))
def test_sht2(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
fft = RealFFTOperator(domain=a, target=b)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)
@expand(product([128, 256], [np.float64, np.float32]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1)
fft = RealFFTOperator(domain=a, target=b)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1 = np.sqrt(out.vdot(out))
v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
assert_allclose(v1, v2, rtol=tol, atol=tol)
@expand(product([128, 256], [np.float64, np.float32]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
fft = RealFFTOperator(domain=a, target=b)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1 = np.sqrt(out.vdot(out))
v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
assert_allclose(v1, v2, rtol=tol, atol=tol)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment