Commit 8ad1b902 authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'fft_op' into 'master'

Document FFTOperator and rename its "inverse_times" to "adjoint_times"

See merge request !84
parents 83faac66 f9583d26
Pipeline #12217 passed with stages
in 12 minutes and 9 seconds
......@@ -34,13 +34,72 @@ from transformations import RGRGTransformation,\
class FFTOperator(LinearOperator):
"""Transforms between a pair of position and harmonic domains.
Built-in domain pairs are
- a harmonic and a non-harmonic RGSpace (with matching distances)
- a HPSpace and a LMSpace
- a GLSpace and a LMSpace
Within a domain pair, both orderings are possible.
The operator provides a "times" and an "adjoint_times" operation.
For a pair of RGSpaces, the "adjoint_times" operation is equivalent to
"inverse_times"; for the sphere-related domains this is not the case, since
the operator matrix is not square.
Parameters
----------
domain: Space or single-element tuple of Spaces
The domain of the data that is input by "times" and output by
"adjoint_times".
target: Space or single-element tuple of Spaces (optional)
The domain of the data that is output by "times" and input by
"adjoint_times".
If omitted, a co-domain will be chosen automatically.
Whenever "domain" is an RGSpace, the codomain (and its parameters) are
uniquely determined (except for "zerocenter").
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most situations,
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is
always available, but "fftw" offers higher performance and
parallelization. For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional)
Data type of the fields that go into "times" and come out of
"adjoint_times". Default is "numpy.complex".
target_dtype: data type (optional)
Data type of the fields that go into "adjoint_times" and come out of
"times". Default is "numpy.complex".
Attributes
----------
domain: Tuple of Spaces (with one entry)
The domain of the data that is input by "times" and output by
"adjoint_times".
target: Tuple of Spaces (with one entry)
The domain of the data that is output by "times" and input by
"adjoint_times".
unitary: bool
Returns True if the operator is unitary (currently only the case if
the domain and codomain are RGSpaces), else False.
Raises
------
ValueError:
if "domain" or "target" are not of the proper type.
"""
# ---Class attributes---
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
GLSpace: LMSpace,
LMSpace: HPSpace,
LMSpace: GLSpace,
}
transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
......@@ -52,7 +111,7 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), target=None, module=None,
def __init__(self, domain, target=None, module=None,
domain_dtype=None, target_dtype=None):
# Initialize domain and target
......@@ -83,8 +142,8 @@ class FFTOperator(LinearOperator):
# Store the dtype information
if domain_dtype is None:
self.logger.info("Setting domain_dtype to np.float.")
self.domain_dtype = np.float
self.logger.info("Setting domain_dtype to np.complex.")
self.domain_dtype = np.complex
else:
self.domain_dtype = np.dtype(domain_dtype)
......@@ -118,7 +177,7 @@ class FFTOperator(LinearOperator):
return result_field
def _inverse_times(self, x, spaces):
def _adjoint_times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
# this case means that x lives on only one space, which is
......@@ -154,12 +213,38 @@ class FFTOperator(LinearOperator):
@property
def unitary(self):
return True
return (self._forward_transformation.unitary and
self._backward_transformation.unitary)
# ---Added properties and methods---
@classmethod
def get_default_codomain(cls, domain):
"""Returns a codomain to the given domain.
Parameters
----------
domain: Space
An instance of RGSpace, HPSpace, GLSpace or LMSpace.
Returns
-------
target: Space
A (more or less perfect) counterpart to "domain" with respect
to a FFT operation.
Whenever "domain" is an RGSpace, the codomain (and its parameters)
are uniquely determined (except for "zerocenter").
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most
situations. For full control however, the user should not rely on
this method.
Raises
------
ValueError:
if no default codomain is defined for "domain".
"""
domain_class = domain.__class__
try:
codomain_class = cls.default_codomain_dictionary[domain_class]
......
......@@ -45,6 +45,10 @@ class GLLMTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......
......@@ -46,6 +46,10 @@ class HPLMTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......@@ -98,7 +102,7 @@ class HPLMTransformation(SlicingTransformation):
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal,
resultImag] = [pyHealpix.map2alm_iter(x, lmax, mmax, 3)
resultImag] = [pyHealpix.map2alm(x, lmax, mmax)
for x in (inp.real, inp.imag)]
[resultReal,
......@@ -108,7 +112,7 @@ class HPLMTransformation(SlicingTransformation):
result = self._combine_complex_result(resultReal, resultImag)
else:
result = pyHealpix.map2alm_iter(inp, lmax, mmax, 3)
result = pyHealpix.map2alm(inp, lmax, mmax)
result = lm_transformation_helper.buildIdx(result, lmax=lmax)
return result
......@@ -45,6 +45,10 @@ class LMGLTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......
......@@ -44,6 +44,10 @@ class LMHPTransformation(SlicingTransformation):
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
@classmethod
def get_codomain(cls, domain):
"""
......
......@@ -23,6 +23,9 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain, module)
......@@ -42,6 +45,12 @@ class RGRGTransformation(Transformation):
else:
raise ValueError('Unsupported FFT module:' + module)
# ---Mandatory properties and methods---
@property
def unitary(self):
return True
@classmethod
def get_codomain(cls, domain, zerocenter=None):
"""
......
......@@ -37,6 +37,10 @@ class Transformation(Loggable, object):
self.domain = domain
self.codomain = codomain
@abc.abstractproperty
def unitary(self):
raise NotImplementedError
@classmethod
def get_codomain(cls, domain):
raise NotImplementedError
......
......@@ -57,34 +57,49 @@ class LinearOperator(Loggable, object):
def inverse_times(self, x, spaces=None, **kwargs):
spaces = self._check_input_compatibility(x, spaces, inverse=True)
y = self._inverse_times(x, spaces, **kwargs)
try:
y = self._inverse_times(x, spaces, **kwargs)
except(NotImplementedError):
if (self.unitary):
y = self._adjoint_times(x, spaces, **kwargs)
else:
raise
return y
def adjoint_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.inverse_times(x, spaces)
spaces = self._check_input_compatibility(x, spaces, inverse=True)
y = self._adjoint_times(x, spaces, **kwargs)
try:
y = self._adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if (self.unitary):
y = self._inverse_times(x, spaces, **kwargs)
else:
raise
return y
def adjoint_inverse_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.times(x, spaces)
spaces = self._check_input_compatibility(x, spaces)
y = self._adjoint_inverse_times(x, spaces, **kwargs)
try:
y = self._adjoint_inverse_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
return y
def inverse_adjoint_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.times(x, spaces, **kwargs)
spaces = self._check_input_compatibility(x, spaces)
y = self._inverse_adjoint_times(x, spaces)
try:
y = self._inverse_adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
return y
def _times(self, x, spaces):
......
......@@ -26,9 +26,8 @@ from nifty.config import dependency_injector as di
from nifty import Field,\
RGSpace,\
LMSpace,\
RGRGTransformation, \
LMGLTransformation, \
LMHPTransformation, \
HPSpace,\
GLSpace,\
FFTOperator
from itertools import product
......@@ -75,28 +74,30 @@ class Misc_Tests(unittest.TestCase):
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d)
b = RGRGTransformation.get_codomain(a, zerocenter=zc2)
b = RGSpace(dim1, zerocenter=zc2,distances=1./(dim1*d),harmonic=True)
fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
target_dtype=_harmonic_type(itp), module=module)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=itp)
out = fft.inverse_times(fft.times(inp))
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw"], [10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d, itp):
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "fftw" and "pyfftw" not in di:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=d)
b = RGRGTransformation.get_codomain(a, zerocenter=[zc3, zc4])
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1,d2])
b = RGSpace([dim1, dim2], zerocenter=[zc3, zc4],
distances=[1./(dim1*d1),1./(dim2*d2)],harmonic=True)
fft = FFTOperator(domain=a, target=b, domain_dtype=itp,
target_dtype=_harmonic_type(itp), module=module)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=itp)
out = fft.inverse_times(fft.times(inp))
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product([0, 3, 6, 11, 30],
......@@ -106,11 +107,11 @@ class Misc_Tests(unittest.TestCase):
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = LMGLTransformation.get_codomain(a)
b = GLSpace(nlat=lm+1)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=tp)
out = fft.inverse_times(fft.times(inp))
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product([128, 256],
......@@ -119,9 +120,41 @@ class Misc_Tests(unittest.TestCase):
if 'pyHealpix' not in di:
raise SkipTest
a = LMSpace(lmax=lm)
b = LMHPTransformation.get_codomain(a)
b = HPSpace(nside=lm//2)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-3)
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in di:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1=np.sqrt(out.dot(out))
v2=np.sqrt(inp.dot(fft.adjoint_times(out)))
assert_allclose(v1,v2, rtol=tol, atol=tol)
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in di:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1=np.sqrt(out.dot(out))
v2=np.sqrt(inp.dot(fft.adjoint_times(out)))
assert_allclose(v1,v2, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment