Commit 666c9396 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

FFTOperator: fix "unitary" property

LinearOperator: better fallback strategy if a specific operation is not defined
test_misc: test condition on scalar products under FFTOperator
parent e7a77f88
...@@ -71,12 +71,11 @@ class FFTOperator(LinearOperator): ...@@ -71,12 +71,11 @@ class FFTOperator(LinearOperator):
on the sphere the default is (unsurprisingly) "pyHealpix". on the sphere the default is (unsurprisingly) "pyHealpix".
domain_dtype: data type (optional) domain_dtype: data type (optional)
Data type of the fields that go into "times" and come out of Data type of the fields that go into "times" and come out of
"adjoint_times". Default is "numpy.float". "adjoint_times". Default is "numpy.complex".
target_dtype: data type (optional) target_dtype: data type (optional)
Data type of the fields that go into "adjoint_times" and come out of Data type of the fields that go into "adjoint_times" and come out of
"times". Default is "numpy.complex". "times". Default is "numpy.complex".
(MR: I feel this is not really a good idea, since it makes no sense for (MR: Wouldn't it make sense to specify data types
SHTs. Also, wouldn't it make sense to specify data types
only to "times" and "adjoint_times"? Does the operator itself really only to "times" and "adjoint_times"? Does the operator itself really
need to know this, or only the individual call?) need to know this, or only the individual call?)
...@@ -90,11 +89,8 @@ class FFTOperator(LinearOperator): ...@@ -90,11 +89,8 @@ class FFTOperator(LinearOperator):
The domain of the data that is output by "times" and input by The domain of the data that is output by "times" and input by
"adjoint_times". "adjoint_times".
unitary: bool unitary: bool
Returns False. Returns True if the operator is unitary (currently only the case if
This is strictly speaking a lie, because FFTOperators on RGSpaces are the domain and codomain are RGSpaces), else False.
in fact unitary ... but if we return True in this case, then
LinearOperator will call _inverse_times instead of _adjoint_times, which
does not exist. This needs some more work.
Raises Raises
------ ------
...@@ -104,6 +100,9 @@ class FFTOperator(LinearOperator): ...@@ -104,6 +100,9 @@ class FFTOperator(LinearOperator):
""" """
# ---Class attributes--- # ---Class attributes---
# Domains for which FFTOperator is unitary
unitary_list = (RGSpace,)
default_codomain_dictionary = {RGSpace: RGSpace, default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace, HPSpace: LMSpace,
GLSpace: LMSpace, GLSpace: LMSpace,
...@@ -156,8 +155,8 @@ class FFTOperator(LinearOperator): ...@@ -156,8 +155,8 @@ class FFTOperator(LinearOperator):
# RGSpaces. # RGSpaces.
# Store the dtype information # Store the dtype information
if domain_dtype is None: if domain_dtype is None:
self.logger.info("Setting domain_dtype to np.float.") self.logger.info("Setting domain_dtype to np.complex.")
self.domain_dtype = np.float self.domain_dtype = np.complex
else: else:
self.domain_dtype = np.dtype(domain_dtype) self.domain_dtype = np.dtype(domain_dtype)
...@@ -227,7 +226,7 @@ class FFTOperator(LinearOperator): ...@@ -227,7 +226,7 @@ class FFTOperator(LinearOperator):
@property @property
def unitary(self): def unitary(self):
return False return type(self.domain[0]) in self.unitary_list
# ---Added properties and methods--- # ---Added properties and methods---
......
...@@ -57,34 +57,49 @@ class LinearOperator(Loggable, object): ...@@ -57,34 +57,49 @@ class LinearOperator(Loggable, object):
def inverse_times(self, x, spaces=None, **kwargs): def inverse_times(self, x, spaces=None, **kwargs):
spaces = self._check_input_compatibility(x, spaces, inverse=True) spaces = self._check_input_compatibility(x, spaces, inverse=True)
y = self._inverse_times(x, spaces, **kwargs) try:
y = self._inverse_times(x, spaces, **kwargs)
except(NotImplementedError):
if (self.unitary):
y = self._adjoint_times(x, spaces, **kwargs)
else:
raise
return y return y
def adjoint_times(self, x, spaces=None, **kwargs): def adjoint_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.inverse_times(x, spaces)
spaces = self._check_input_compatibility(x, spaces, inverse=True) spaces = self._check_input_compatibility(x, spaces, inverse=True)
y = self._adjoint_times(x, spaces, **kwargs) try:
y = self._adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if (self.unitary):
y = self._inverse_times(x, spaces, **kwargs)
else:
raise
return y return y
def adjoint_inverse_times(self, x, spaces=None, **kwargs): def adjoint_inverse_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.times(x, spaces)
spaces = self._check_input_compatibility(x, spaces) spaces = self._check_input_compatibility(x, spaces)
y = self._adjoint_inverse_times(x, spaces, **kwargs) try:
y = self._adjoint_inverse_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
return y return y
def inverse_adjoint_times(self, x, spaces=None, **kwargs): def inverse_adjoint_times(self, x, spaces=None, **kwargs):
if self.unitary:
return self.times(x, spaces, **kwargs)
spaces = self._check_input_compatibility(x, spaces) spaces = self._check_input_compatibility(x, spaces)
y = self._inverse_adjoint_times(x, spaces) try:
y = self._inverse_adjoint_times(x, spaces, **kwargs)
except(NotImplementedError):
if self.unitary:
y = self._times(x, spaces, **kwargs)
else:
raise
return y return y
def _times(self, x, spaces): def _times(self, x, spaces):
......
...@@ -126,3 +126,35 @@ class Misc_Tests(unittest.TestCase): ...@@ -126,3 +126,35 @@ class Misc_Tests(unittest.TestCase):
dtype=tp) dtype=tp)
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1) assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in di:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1=np.sqrt(out.dot(out))
v2=np.sqrt(inp.dot(fft.adjoint_times(out)))
assert_allclose(v1,v2, rtol=tol, atol=tol)
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in di:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp)
out = fft.times(inp)
v1=np.sqrt(out.dot(out))
v2=np.sqrt(inp.dot(fft.adjoint_times(out)))
assert_allclose(v1,v2, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment