Commit 666c9396 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

FFTOperator: fix "unitary" property

LinearOperator: better fallback strategy if a specific operation is not defined
test_misc: test condition on scalar products under FFTOperator
parent e7a77f88
......@@ -71,12 +71,11 @@ class FFTOperator(LinearOperator):
on the sphere the default is (unsurprisingly) "pyHealpix".
domain_dtype: data type (optional)
Data type of the fields that go into "times" and come out of
"adjoint_times". Default is "numpy.float".
"adjoint_times". Default is "numpy.complex".
target_dtype: data type (optional)
Data type of the fields that go into "adjoint_times" and come out of
"times". Default is "numpy.complex".
(MR: I feel this is not really a good idea, since it makes no sense for
SHTs. Also, wouldn't it make sense to specify data types
(MR: Wouldn't it make sense to specify data types
only to "times" and "adjoint_times"? Does the operator itself really
need to know this, or only the individual call?)
......@@ -90,11 +89,8 @@ class FFTOperator(LinearOperator):
The domain of the data that is output by "times" and input by
"adjoint_times".
unitary: bool
Returns False.
This is strictly speaking a lie, because FFTOperators on RGSpaces are
in fact unitary ... but if we return True in this case, then
LinearOperator will call _inverse_times instead of _adjoint_times, which
does not exist. This needs some more work.
Returns True if the operator is unitary (currently only the case if
the domain and codomain are RGSpaces), else False.
Raises
------
......@@ -104,6 +100,9 @@ class FFTOperator(LinearOperator):
"""
# ---Class attributes---
# Domains for which FFTOperator is unitary
unitary_list = (RGSpace,)
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
GLSpace: LMSpace,
......@@ -156,8 +155,8 @@ class FFTOperator(LinearOperator):
# RGSpaces.
# Store the dtype information
if domain_dtype is None:
self.logger.info("Setting domain_dtype to np.float.")
self.domain_dtype = np.float
self.logger.info("Setting domain_dtype to np.complex.")
self.domain_dtype = np.complex
else:
self.domain_dtype = np.dtype(domain_dtype)
......@@ -227,7 +226,7 @@ class FFTOperator(LinearOperator):
@property
def unitary(self):
return False
return type(self.domain[0]) in self.unitary_list
# ---Added properties and methods---
......
......@@ -57,34 +57,49 @@ class LinearOperator(Loggable, object):
def inverse_times(self, x, spaces=None, **kwargs):
spaces = self._check_input_compatibility(x, spaces, inverse=True)
try:
y = self._inverse_times(x, spaces, **kwargs)
except(NotImplementedError):
if (self.unitary):
y = self._adjoint_times(x, spaces, **kwargs)
else:
raise
return y
def adjoint_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.inverse_times(x, spaces)
spaces = self._check_input_compatibility(x, spaces, inverse=True)
try:
y = self._adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if (self.unitary):
y = self._inverse_times(x, spaces, **kwargs)
else:
raise
return y
def adjoint_inverse_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.times(x, spaces)
spaces = self._check_input_compatibility(x, spaces)
try:
y = self._adjoint_inverse_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
return y
def inverse_adjoint_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.times(x, spaces, **kwargs)
spaces = self._check_input_compatibility(x, spaces)
y = self._inverse_adjoint_times(x, spaces)
try:
y = self._inverse_adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
return y
def _times(self, x, spaces):
......
......@@ -126,3 +126,35 @@ class Misc_Tests(unittest.TestCase):
dtype=tp)
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in di:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1=np.sqrt(out.dot(out))
v2=np.sqrt(inp.dot(fft.adjoint_times(out)))
assert_allclose(v1,v2, rtol=tol, atol=tol)
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in di:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1=np.sqrt(out.dot(out))
v2=np.sqrt(inp.dot(fft.adjoint_times(out)))
assert_allclose(v1,v2, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment