Commit 0a4c5b7f authored by theos's avatar theos
Browse files

Merge branch 'add_axes_to_rg_fft' into feature/field_multiple_space

# Conflicts:
#	test/test_nifty_spaces.py
parents 3dd360d1 c435e42f
Pipeline #5116 skipped
......@@ -33,22 +33,25 @@ def get_slice_list(shape, axes):
if not shape:
raise ValueError(about._errors.cstring("ERROR: shape cannot be None."))
if not all(axis < len(shape) for axis in axes):
raise ValueError(
about._errors.cstring("ERROR: axes(axis) does not match shape.")
)
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables = [range(y) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
yield slice_list
if axes:
if not all(axis < len(shape) for axis in axes):
raise ValueError(
about._errors.cstring("ERROR: \
axes(axis) does not match shape.")
)
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables =\
[range(y) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
yield slice_list
else:
yield [slice(None, None)]
return
def hermitianize_gaussian(x):
......
This diff is collapsed.
......@@ -801,7 +801,7 @@ class rg_space(point_space):
result = np.asscalar(np.real(result))
return result
def calc_transform(self, x, codomain=None, **kwargs):
def calc_transform(self, x, codomain=None, axes=None, **kwargs):
"""
Computes the transform of a given array of field values.
......@@ -812,6 +812,8 @@ class rg_space(point_space):
codomain : nifty.rg_space, *optional*
codomain space to which the transformation shall map
(default: None).
axes : None or tuple
Axes in the array which should be transformed.
Returns
-------
......@@ -834,18 +836,12 @@ class rg_space(point_space):
# Perform the transformation
Tx = self.fft_machine.transform(val=x, domain=self, codomain=codomain,
**kwargs)
axes=axes, **kwargs)
if not codomain.harmonic:
# correct for inverse fft
Tx = codomain.calc_weight(Tx, power=-1)
# when the codomain space is purely real, the result of the
# transformation must be corrected accordingly. Using the casting
# method of codomain is sufficient
# TODO: Let .transform yield the correct dtype
Tx = codomain.cast(Tx)
return Tx
def calc_smooth(self, x, sigma=0, codomain=None):
......@@ -1649,4 +1645,4 @@ class rg_space(point_space):
def __repr__(self):
string = super(rg_space, self).__repr__()
string += repr(self.fft_machine) + "\n "
return string
\ No newline at end of file
return string
import nifty as nt
import numpy as np
import unittest
import d2o
class TestFFTWTransform(unittest.TestCase):
def test_comm(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a)
b.comm = [1, 2, 3] # change comm to something not supported
with self.assertRaises(RuntimeError):
x.fft_machine.transform(b, x, x.get_codomain())
def test_shapemismatch(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
with self.assertRaises(ValueError):
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1, 2))
def test_local_ndarray(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
self.assertTrue(
np.allclose(
x.fft_machine.transform(a, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_local_notzero(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(1,)),
np.fft.fftn(a, axes=(1,))
), 'results do not match numpy.fft.fftn'
)
def test_not(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='not')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesnone(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesnone_equal(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesall(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1)),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesall_equal(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1)),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero_equal(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero_not(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='not')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
if __name__ == '__main__':
unittest.main()
......@@ -742,33 +742,35 @@ class Test_RG_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(
itertools.product([True],
['pyfftw']),
testcase_func_name=custom_name_func)
def test_get_random_values(self, harmonic, ):
x = rg_space((4, 4), complexity=1, harmonic=harmonic)
# pm1
data = x.get_random_values(random='pm1')
flipped_data = flip(x, data)
assert(check_almost_equality(x, data, flipped_data))
# gau
data = x.get_random_values(random='gau', mean=4 + 3j, std=2)
flipped_data = flip(x, data)
assert(check_almost_equality(x, data, flipped_data))
# uni
data = x.get_random_values(random='uni', vmin=-2, vmax=4)
flipped_data = flip(x, data)
assert(check_almost_equality(x, data, flipped_data))
# syn
data = x.get_random_values(random='syn',
spec=lambda x: 42 / (1 + x)**3)
flipped_data = flip(x, data)
assert(check_almost_equality(x, data, flipped_data))
# @parameterized.expand(
# itertools.product([True], #[True, False],
# ['fftw']),
# #DATAMODELS['rg_space']),
# testcase_func_name=custom_name_func)
# def test_get_random_values(self, harmonic, datamodel):
# x = rg_space((4, 4), complexity=1, harmonic=harmonic,
# datamodel=datamodel)
#
# # pm1
# data = x.get_random_values(random='pm1')
# flipped_data = flip(x, data)
# assert(check_almost_equality(x, data, flipped_data))
#
# # gau
# data = x.get_random_values(random='gau', mean=4 + 3j, std=2)
# flipped_data = flip(x, data)
# assert(check_almost_equality(x, data, flipped_data))
#
# # uni
# data = x.get_random_values(random='uni', vmin=-2, vmax=4)
# flipped_data = flip(x, data)
# assert(check_almost_equality(x, data, flipped_data))
#
# # syn
# data = x.get_random_values(random='syn',
# spec=lambda x: 42 / (1 + x)**3)
# flipped_data = flip(x, data)
# assert(check_almost_equality(x, data, flipped_data))
###############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment