Commit fd98cba2 authored by Theo Steininger's avatar Theo Steininger

Integrated real harmonic RGSpace representation deeply into NIFTy.

parent 39b6628e
Pipeline #16544 failed with stage
in 6 minutes and 17 seconds
...@@ -2,13 +2,14 @@ import numpy as np ...@@ -2,13 +2,14 @@ import numpy as np
from nifty import RGSpace, PowerSpace, Field, FFTOperator, ComposedOperator,\ from nifty import RGSpace, PowerSpace, Field, FFTOperator, ComposedOperator,\
DiagonalOperator, ResponseOperator, plotting,\ DiagonalOperator, ResponseOperator, plotting,\
create_power_operator create_power_operator, nifty_configuration
from nifty.library import WienerFilterCurvature from nifty.library import WienerFilterCurvature
if __name__ == "__main__": if __name__ == "__main__":
distribution_strategy = 'not' nifty_configuration['default_distribution_strategy'] = 'fftw'
nifty_configuration['harmonic_rg_base'] = 'real'
# Setting up variable parameters # Setting up variable parameters
...@@ -36,19 +37,17 @@ if __name__ == "__main__": ...@@ -36,19 +37,17 @@ if __name__ == "__main__":
signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels) signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
harmonic_space = FFTOperator.get_default_codomain(signal_space) harmonic_space = FFTOperator.get_default_codomain(signal_space)
fft = FFTOperator(harmonic_space, target=signal_space, fft = FFTOperator(harmonic_space, target=signal_space)
domain_dtype=np.complex, target_dtype=np.float) power_space = PowerSpace(harmonic_space)
power_space = PowerSpace(harmonic_space,
distribution_strategy=distribution_strategy)
# Creating the mock data # Creating the mock data
S = create_power_operator(harmonic_space, power_spectrum=power_spectrum, S = create_power_operator(harmonic_space, power_spectrum=power_spectrum)
distribution_strategy=distribution_strategy)
mock_power = Field(power_space, val=power_spectrum, mock_power = Field(power_space, val=power_spectrum)
distribution_strategy=distribution_strategy)
np.random.seed(43) np.random.seed(43)
mock_harmonic = mock_power.power_synthesize(real_signal=True) mock_harmonic = mock_power.power_synthesize(real_signal=True)
if nifty_configuration['harmonic_rg_base'] == 'real':
mock_harmonic = mock_harmonic.real
mock_signal = fft(mock_harmonic) mock_signal = fft(mock_harmonic)
R = ResponseOperator(signal_space, sigma=(response_sigma,)) R = ResponseOperator(signal_space, sigma=(response_sigma,))
...@@ -74,9 +73,10 @@ if __name__ == "__main__": ...@@ -74,9 +73,10 @@ if __name__ == "__main__":
plotter = plotting.RG2DPlotter() plotter = plotting.RG2DPlotter()
plotter.path = 'mock_signal.html' plotter.path = 'mock_signal.html'
plotter(mock_signal) plotter(mock_signal.real)
plotter.path = 'data.html' plotter.path = 'data.html'
plotter(Field(signal_space, plotter(Field(
val=data.val.get_full_data().reshape(signal_space.shape))) signal_space,
val=data.val.get_full_data().real.reshape(signal_space.shape)))
plotter.path = 'map.html' plotter.path = 'map.html'
plotter(m_s) plotter(m_s.real)
import numpy as np
from nifty import RGSpace, PowerSpace, Field, RealFFTOperator,\
ComposedOperator, DiagonalOperator, ResponseOperator,\
plotting, create_power_operator
from nifty.library import WienerFilterCurvature
if __name__ == "__main__":
distribution_strategy = 'not'
# Setting up variable parameters
# Typical distance over which the field is correlated
correlation_length = 0.05
# Variance of field in position space sqrt(<|s_x|^2>)
field_variance = 2.
# smoothing length of response (in same unit as L)
response_sigma = 0.01
# The signal to noise ratio
signal_to_noise = 0.7
# note that field_variance**2 = a*k_0/4. for this analytic form of power
# spectrum
def power_spectrum(k):
a = 4 * correlation_length * field_variance**2
return a / (1 + k * correlation_length) ** 4
# Setting up the geometry
# Total side-length of the domain
L = 2.
# Grid resolution (pixels per axis)
N_pixels = 512
signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
harmonic_space = RealFFTOperator.get_default_codomain(signal_space)
fft = RealFFTOperator(harmonic_space, target=signal_space)
power_space = PowerSpace(harmonic_space,
distribution_strategy=distribution_strategy)
# Creating the mock data
S = create_power_operator(harmonic_space, power_spectrum=power_spectrum,
distribution_strategy=distribution_strategy)
mock_power = Field(power_space, val=power_spectrum,
distribution_strategy=distribution_strategy)
np.random.seed(43)
mock_harmonic = mock_power.power_synthesize(real_signal=True)
mock_harmonic = mock_harmonic.real + mock_harmonic.imag
mock_signal = fft(mock_harmonic)
R = ResponseOperator(signal_space, sigma=(response_sigma,))
data_domain = R.target[0]
R_harmonic = ComposedOperator([fft, R], default_spaces=[0, 0])
N = DiagonalOperator(data_domain,
diagonal=mock_signal.var()/signal_to_noise,
bare=True)
noise = Field.from_random(domain=data_domain,
random_type='normal',
std=mock_signal.std()/np.sqrt(signal_to_noise),
mean=0)
data = R(mock_signal) + noise
# Wiener filter
j = R_harmonic.adjoint_times(N.inverse_times(data))
wiener_curvature = WienerFilterCurvature(S=S, N=N, R=R_harmonic)
m = wiener_curvature.inverse_times(j)
m_s = fft(m)
plotter = plotting.RG2DPlotter()
plotter.path = 'mock_signal.html'
plotter(mock_signal)
plotter.path = 'data.html'
plotter(Field(signal_space,
val=data.val.get_full_data().reshape(signal_space.shape)))
plotter.path = 'map.html'
plotter(m_s)
...@@ -70,11 +70,18 @@ variable_default_distribution_strategy = keepers.Variable( ...@@ -70,11 +70,18 @@ variable_default_distribution_strategy = keepers.Variable(
if z == 'fftw' else True), if z == 'fftw' else True),
genus='str') genus='str')
variable_harmonic_rg_base = keepers.Variable(
'harmonic_rg_base',
['real', 'complex'],
lambda z: z in ['real', 'complex'],
genus='str')
nifty_configuration = keepers.get_Configuration( nifty_configuration = keepers.get_Configuration(
name='NIFTy', name='NIFTy',
variables=[variable_fft_module, variables=[variable_fft_module,
variable_default_field_dtype, variable_default_field_dtype,
variable_default_distribution_strategy], variable_default_distribution_strategy,
variable_harmonic_rg_base],
file_name='NIFTy.conf', file_name='NIFTy.conf',
search_paths=[os.path.expanduser('~') + "/.config/nifty/", search_paths=[os.path.expanduser('~') + "/.config/nifty/",
os.path.expanduser('~') + "/.config/", os.path.expanduser('~') + "/.config/",
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
from __future__ import division from __future__ import division
import ast import ast
import itertools
import numpy as np import numpy as np
from keepers import Versionable,\ from keepers import Versionable,\
......
...@@ -18,4 +18,3 @@ ...@@ -18,4 +18,3 @@
from transformations import * from transformations import *
from .fft_operator import FFTOperator from .fft_operator import FFTOperator
from .real_fft_operator import RealFFTOperator
...@@ -142,17 +142,11 @@ class FFTOperator(LinearOperator): ...@@ -142,17 +142,11 @@ class FFTOperator(LinearOperator):
backward_class, self.target[0], self.domain[0], module=module) backward_class, self.target[0], self.domain[0], module=module)
# Store the dtype information # Store the dtype information
if domain_dtype is None: self.domain_dtype = \
self.logger.info("Setting domain_dtype to np.complex.") None if domain_dtype is None else np.dtype(domain_dtype)
self.domain_dtype = np.complex
else:
self.domain_dtype = np.dtype(domain_dtype)
if target_dtype is None: self.target_dtype = \
self.logger.info("Setting target_dtype to np.complex.") None if target_dtype is None else np.dtype(target_dtype)
self.target_dtype = np.complex
else:
self.target_dtype = np.dtype(target_dtype)
def _times(self, x, spaces): def _times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
...@@ -172,8 +166,10 @@ class FFTOperator(LinearOperator): ...@@ -172,8 +166,10 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain) result_domain = list(x.domain)
result_domain[spaces[0]] = self.target[0] result_domain[spaces[0]] = self.target[0]
result_dtype = \
new_val.dtype if self.target_dtype is None else self.target_dtype
result_field = x.copy_empty(domain=result_domain, result_field = x.copy_empty(domain=result_domain,
dtype=self.target_dtype) dtype=result_dtype)
result_field.set_val(new_val=new_val, copy=True) result_field.set_val(new_val=new_val, copy=True)
return result_field return result_field
...@@ -196,8 +192,11 @@ class FFTOperator(LinearOperator): ...@@ -196,8 +192,11 @@ class FFTOperator(LinearOperator):
result_domain = list(x.domain) result_domain = list(x.domain)
result_domain[spaces[0]] = self.domain[0] result_domain[spaces[0]] = self.domain[0]
result_dtype = \
new_val.dtype if self.domain_dtype is None else self.domain_dtype
result_field = x.copy_empty(domain=result_domain, result_field = x.copy_empty(domain=result_domain,
dtype=self.domain_dtype) dtype=result_dtype)
result_field.set_val(new_val=new_val, copy=True) result_field.set_val(new_val=new_val, copy=True)
return result_field return result_field
......
This diff is collapsed.
...@@ -260,7 +260,7 @@ class MPIFFT(Transform): ...@@ -260,7 +260,7 @@ class MPIFFT(Transform):
p() p()
if p.has_output: if p.has_output:
result = p.output_array result = p.output_array.copy()
if result.shape != val.shape: if result.shape != val.shape:
raise ValueError("Output shape is different than input shape. " raise ValueError("Output shape is different than input shape. "
"Maybe fftw tries to optimize the " "Maybe fftw tries to optimize the "
......
...@@ -43,6 +43,8 @@ class RGRGTransformation(Transformation): ...@@ -43,6 +43,8 @@ class RGRGTransformation(Transformation):
else: else:
raise ValueError('Unsupported FFT module:' + module) raise ValueError('Unsupported FFT module:' + module)
self.harmonic_base = nifty_configuration['harmonic_rg_base']
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property @property
...@@ -144,7 +146,31 @@ class RGRGTransformation(Transformation): ...@@ -144,7 +146,31 @@ class RGRGTransformation(Transformation):
val = self._transform.domain.weight(val, power=1, axes=axes) val = self._transform.domain.weight(val, power=1, axes=axes)
# Perform the transformation # Perform the transformation
Tval = self._transform.transform(val, axes, **kwargs) if self.harmonic_base == 'complex':
Tval = self._transform.transform(val, axes, **kwargs)
else:
if issubclass(val.dtype.type, np.complexfloating):
Tval_real = self._transform.transform(val.real, axes,
**kwargs)
Tval_imag = self._transform.transform(val.imag, axes,
**kwargs)
if self.codomain.harmonic:
Tval_real.data.real += Tval_real.data.imag
Tval_real.data.imag = \
Tval_imag.data.real + Tval_imag.data.imag
else:
Tval_real.data.real -= Tval_real.data.imag
Tval_real.data.imag = \
Tval_imag.data.real - Tval_imag.data.imag
Tval = Tval_real
else:
Tval = self._transform.transform(val, axes, **kwargs)
if self.codomain.harmonic:
Tval.data.real += Tval.data.imag
else:
Tval.data.real -= Tval.data.imag
Tval = Tval.real
if not self._transform.codomain.harmonic: if not self._transform.codomain.harmonic:
# correct for inverse fft. # correct for inverse fft.
......
...@@ -36,7 +36,7 @@ from d2o import distributed_data_object,\ ...@@ -36,7 +36,7 @@ from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.spaces.space import Space from nifty.spaces.space import Space
from nifty.config import nifty_configuration
class RGSpace(Space): class RGSpace(Space):
""" """
...@@ -122,33 +122,36 @@ class RGSpace(Space): ...@@ -122,33 +122,36 @@ class RGSpace(Space):
# return fixed_points # return fixed_points
def hermitianize_inverter(self, x, axes): def hermitianize_inverter(self, x, axes):
# calculate the number of dimensions the input array has if nifty_configuration['harmonic_rg_base'] == 'real':
dimensions = len(x.shape) return x
# prepare the slicing object which will be used for mirroring else:
slice_primitive = [slice(None), ] * dimensions # calculate the number of dimensions the input array has
# copy the input data dimensions = len(x.shape)
y = x.copy() # prepare the slicing object which will be used for mirroring
slice_primitive = [slice(None), ] * dimensions
# flip in the desired directions # copy the input data
for k in range(len(axes)): y = x.copy()
i = axes[k]
slice_picker = slice_primitive[:] # flip in the desired directions
slice_inverter = slice_primitive[:] for k in range(len(axes)):
if (not self.zerocenter[k]) or self.shape[k] % 2 == 0: i = axes[k]
slice_picker[i] = slice(1, None, None) slice_picker = slice_primitive[:]
slice_inverter[i] = slice(None, 0, -1) slice_inverter = slice_primitive[:]
else: if (not self.zerocenter[k]) or self.shape[k] % 2 == 0:
slice_picker[i] = slice(None) slice_picker[i] = slice(1, None, None)
slice_inverter[i] = slice(None, None, -1) slice_inverter[i] = slice(None, 0, -1)
slice_picker = tuple(slice_picker) else:
slice_inverter = tuple(slice_inverter) slice_picker[i] = slice(None)
slice_inverter[i] = slice(None, None, -1)
try: slice_picker = tuple(slice_picker)
y.set_data(to_key=slice_picker, data=y, slice_inverter = tuple(slice_inverter)
from_key=slice_inverter)
except(AttributeError): try:
y[slice_picker] = y[slice_inverter] y.set_data(to_key=slice_picker, data=y,
return y from_key=slice_inverter)
except(AttributeError):
y[slice_picker] = y[slice_inverter]
return y
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from nifty import Space,\ from nifty import Space,\
PowerSpace,\ PowerSpace,\
Field,\ Field,\
...@@ -71,9 +73,10 @@ def create_power_operator(domain, power_spectrum, dtype=None, ...@@ -71,9 +73,10 @@ def create_power_operator(domain, power_spectrum, dtype=None,
distribution_strategy='not') distribution_strategy='not')
f = fp.power_synthesize(mean=1, std=0, real_signal=False, f = fp.power_synthesize(mean=1, std=0, real_signal=False,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# MR FIXME: we need the real part here. Could this also be achieved
# by setting real_signal=True in the call above? if not issubclass(fp.dtype.type, np.complexfloating):
f = f.real f = f.real
f **= 2 f **= 2
return DiagonalOperator(domain, diagonal=f, bare=True) return DiagonalOperator(domain, diagonal=f, bare=True)
......
...@@ -31,14 +31,7 @@ from itertools import product ...@@ -31,14 +31,7 @@ from itertools import product
from test.common import expand from test.common import expand
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from d2o import STRATEGIES
def _harmonic_type(itp):
otp = itp
if otp == np.float64:
otp = np.complex128
elif otp == np.float32:
otp = np.complex64
return otp
def _get_rtol(tp): def _get_rtol(tp):
...@@ -49,100 +42,119 @@ def _get_rtol(tp): ...@@ -49,100 +42,119 @@ def _get_rtol(tp):
class FFTOperatorTests(unittest.TestCase): class FFTOperatorTests(unittest.TestCase):
@expand(product([10, 11], [False, True], [0.1, 1, 3.7]))
def test_RG_distance_1D(self, dim1, zc1, d):
foo = RGSpace([dim1], zerocenter=zc1, distances=d)
res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2)], 0.)
@expand(product([10, 11], [9, 28], [False, True], [False, True],
[0.1, 1, 3.7]))
def test_RG_distance_2D(self, dim1, dim2, zc1, zc2, d):
foo = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=d)
res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["numpy", "fftw", "fftw_mpi"], @expand(product(["numpy", "fftw", "fftw_mpi"],
[16, ], [False, True], [False, True], [16, ], [0.1, 1, 3.7], STRATEGIES['global'],
[0.1, 1, 3.7], [np.float64, np.float32, np.complex64, np.complex128],
[np.float64, np.complex128, np.float32, np.complex64])) ['real', 'complex']))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp): def test_fft1D(self, module, dim1, d, distribution_strategy, itp, base):
if module == "fftw_mpi": if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'): if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest raise SkipTest
if module == "fftw" and "fftw" not in gdi: if module == "fftw" and "fftw" not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d) a = RGSpace(dim1, distances=d)
b = RGSpace(dim1, zerocenter=zc2, distances=1./(dim1*d), harmonic=True) b = RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
fft = FFTOperator(domain=a, target=b, domain_dtype=itp, fft = FFTOperator(domain=a, target=b, module=module)
target_dtype=_harmonic_type(itp), module=module) fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
np.random.seed(16) np.random.seed(16)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3, inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=itp) dtype=itp,
distribution_strategy=distribution_strategy)
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val.get_full_data(), assert_allclose(inp.val.get_full_data(),
out.val.get_full_data(), out.val.get_full_data(),
rtol=tol, atol=tol) rtol=tol, atol=tol)
@expand(product(["numpy", "fftw", "fftw_mpi"], @expand(product(["numpy", "fftw", "fftw_mpi"],
[12, 15], [9, 12], [False, True], [12, 15], [9, 12], [0.1, 1, 3.7],
[False, True], [False, True], [False, True], [0.1, 1, 3.7], [0.4, 1, 2.7], STRATEGIES['global'],
[0.4, 1, 2.7], [np.float64, np.float32, np.complex64, np.complex128],
[np.float64, np.complex128, np.float32, np.complex64])) ['real', 'complex']))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp): def test_fft2D(self, module, dim1, dim2, d1, d2, distribution_strategy,
itp, base):
if module == "fftw_mpi": if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'): if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest raise SkipTest
if module == "fftw" and "fftw" not in gdi: if module == "fftw" and "fftw" not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(itp) tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2]) a = RGSpace([dim1, dim2], distances=[d1, d2])
b = RGSpace([dim1, dim2], zerocenter=[zc3, zc4], b = RGSpace([dim1, dim2],
distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True) distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
fft = FFTOperator(domain=a, target=b, domain_dtype=itp, fft = FFTOperator(domain=a, target=b, module=module)
target_dtype=_harmonic_type(itp), module=module) fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3, inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=itp) dtype=itp,
distribution_strategy=distribution_strategy)
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw", "fftw_mpi"],
[0, 1, 2],
STRATEGIES['global'],
[np.float64, np.float32, np.complex64, np.complex128],
['real', 'complex']))
def test_composed_fft(self, module, index, distribution_strategy, dtype,
base):
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(dtype)
a = [a1, a2, a3] = [RGSpace((8,)), RGSpace((4, 4)), RGSpace((5, 6))]
fft = FFTOperator(domain=a[index], module=module,
default_spaces=(index,))
fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
inp = Field.from_random(domain=(a1, a2, a3), random_type='normal',
std=7, mean=3, dtype=dtype,
distribution_strategy=distribution_strategy)
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol) assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product([0, 3, 6, 11, 30], @expand(product([0, 3, 6, 11, 30],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.float32, np.complex64, np.complex128]))
def test_sht(self, lm, tp): def test_sht(self, lm, tp):
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(tp) tol = _get_rtol(tp)
a = LMSpace(lmax=lm) a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1) b = GLSpace(nlat=lm+1)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp) fft = FFTOperator(domain=a, target=b)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3, inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=tp) dtype=tp)
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol) assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product([128, 256], @expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.float32, np.complex64, np.complex128]))
def test_sht2(self, lm, tp): def test_sht2(self, lm, tp):
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise SkipTest raise SkipTest
a = LMSpace(lmax=lm) a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2) b = HPSpace(nside=lm//2)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp) fft = FFTOperator(domain=a, target=b)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0, inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp) dtype=tp)
out = fft.adjoint_times(fft.times(inp)) out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1) assert_allclose(inp.val, out.val, rtol=1e-3, atol=1e-1)
@expand(product([128, 256], @expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.float32, np.complex64, np.complex128]))
def test_dotsht(self, lm, tp): def test_dotsht(self, lm, tp):
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise SkipTest raise SkipTest
tol = _get_rtol(tp) tol = _get_rtol(tp)
a = LMSpace(lmax=lm) a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1) b = GLSpace(nlat=lm+1)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp) fft = FFTOperator(domain=a, target=b)
inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0, inp = Field.from_random(domain=a, random_type='normal', std=1, mean=0,
dtype=tp) dtype=tp)
out = fft.times(inp) out = fft.times(inp)
...@@ -151,14 +163,14 @@ class FFTOperatorTests(unittest.TestCase): ...@@ -151,14 +163,14 @@ class FFTOperatorTests(unittest.TestCase):
assert_allclose(v1, v2, rtol=tol, atol=tol) assert_allclose(v1, v2, rtol=tol, atol=tol)
@expand(product([128, 256], @expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64])) [np.float64, np.float32, np.complex64, np.complex128]))
def test_dotsht2(self, lm, tp): def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise SkipTest