Commit 69fa111d authored by Jait Dixit's avatar Jait Dixit
Browse files

Fix issue #49

parent 95d48149
......@@ -7,7 +7,7 @@ import itertools
from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator
from nifty.transformations.transformation import Transformation
import nifty.transformations.transformation as transformation
from nifty.rg.rg_space import gc as RG_GC
import d2o
......@@ -36,49 +36,34 @@ def weighted_np_transform(val, domain, codomain, axes=None):
###############################################################################
rg_fft_modules = []
rg_rg_fft_modules = []
for name in ['gfft', 'gfft_dummy', 'pyfftw']:
if RG_GC.validQ('fft_module', name):
rg_fft_modules += [name]
rg_test_shapes = [(128, 128), (179, 179), (512, 512)]
rg_test_data = np.array(
[[0.38405405 + 0.32460996j, 0.02718878 + 0.08326207j,
0.78792080 + 0.81192595j, 0.17535687 + 0.68054781j,
0.93044845 + 0.71942995j, 0.21179999 + 0.00637665j],
[0.10905553 + 0.3027462j, 0.37361237 + 0.68434316j,
0.94070232 + 0.34129582j, 0.04658034 + 0.4575192j,
0.45057929 + 0.64297612j, 0.01007361 + 0.24953504j],
[0.39579662 + 0.70881906j, 0.01614435 + 0.82603832j,
0.84036344 + 0.50321592j, 0.87699553 + 0.40337862j,
0.11816016 + 0.43332373j, 0.76627757 + 0.66327959j],
[0.77272335 + 0.18277367j, 0.93341953 + 0.58105518j,
0.27227913 + 0.17458168j, 0.70204032 + 0.81397425j,
0.12422993 + 0.19215286j, 0.30897158 + 0.47364969j],
[0.24702012 + 0.54534373j, 0.55206013 + 0.98406613j,
0.57408167 + 0.55685406j, 0.87991341 + 0.52534323j,
0.93912604 + 0.97186519j, 0.77778942 + 0.45812051j],
[0.79367868 + 0.48149411j, 0.42484378 + 0.74870011j,
0.79611264 + 0.50926774j, 0.35372794 + 0.10468412j,
0.46140736 + 0.09449825j, 0.82044644 + 0.95992843j]])
rg_rg_fft_modules += [name]
rg_rg_test_shapes = [(128, 128), (179, 179), (512, 512)]
rg_rg_test_spaces = [(GLSpace(8),), (HPSpace(8),), (LMSpace(8),)]
gl_hp_lm_test_spaces = [(GLSpace(8),), (HPSpace(8),), (RGSpace(8),)]
lm_gl_hp_test_spaces = [(LMSpace(8),), (RGSpace(8),)]
###############################################################################
class TestRGSpaceTransforms(unittest.TestCase):
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_check_codomain_none(self, module):
class TestRGRGTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = RGSpace((8, 8))
with assert_raises(ValueError):
transformator.create(x, None, module=module)
transformator.create(x, None)
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_check_codomain_mismatch(self, module):
@parameterized.expand(
rg_rg_test_spaces,
testcase_func_name=custom_name_func
)
def test_check_codomain_mismatch(self, space):
x = RGSpace((8, 8))
y = LMSpace(8)
with assert_raises(TypeError):
transformator.create(x, y, module=module)
transformator.create(x, space)
@parameterized.expand(
itertools.product([0, 1, 2], [None, (1, 1), (10, 10)], [False, True]),
......@@ -87,10 +72,10 @@ class TestRGSpaceTransforms(unittest.TestCase):
def test_check_codomain_rgspecific(self, complexity, distances, harmonic):
x = RGSpace((8, 8), complexity=complexity,
distances=distances, harmonic=harmonic)
assert (Transformation.check_codomain(x, x.get_codomain()))
assert (Transformation.check_codomain(x, x.get_codomain()))
assert (transformation.RGRGTransformation.check_codomain(x, x.get_codomain()))
assert (transformation.RGRGTransformation.check_codomain(x, x.get_codomain()))
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
@parameterized.expand(rg_rg_fft_modules, testcase_func_name=custom_name_func)
def test_shapemismatch(self, module):
x = RGSpace((8, 8))
b = d2o.distributed_data_object(np.ones((8, 8)))
......@@ -100,7 +85,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
).transform(b, axes=(0, 1, 2))
@parameterized.expand(
itertools.product(rg_fft_modules, rg_test_shapes),
itertools.product(rg_rg_fft_modules, rg_rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_local_ndarray(self, module, shape):
......@@ -114,7 +99,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
)
@parameterized.expand(
itertools.product(rg_fft_modules, rg_test_shapes),
itertools.product(rg_rg_fft_modules, rg_rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_local_notzero(self, module, shape):
......@@ -129,7 +114,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
)
@parameterized.expand(
itertools.product(rg_fft_modules, rg_test_shapes),
itertools.product(rg_rg_fft_modules, rg_rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_not(self, module, shape):
......@@ -144,7 +129,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
)
@parameterized.expand(
itertools.product(rg_test_shapes),
itertools.product(rg_rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_mpi_axesnone(self, shape):
......@@ -159,7 +144,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
)
@parameterized.expand(
itertools.product(rg_test_shapes),
itertools.product(rg_rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_mpi_axesnone_equal(self, shape):
......@@ -173,38 +158,54 @@ class TestRGSpaceTransforms(unittest.TestCase):
weighted_np_transform(a, x, x.get_codomain())
)
class TestGLLMTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = GLSpace(8)
with assert_raises(ValueError):
transformator.create(x, None)
@parameterized.expand(
itertools.product(rg_fft_modules),
testcase_func_name=custom_name_func)
def test_calc_transform_explicit(self, module):
data = rg_test_data.copy()
d2o_data = d2o.distributed_data_object(data)
shape = data.shape
x = RGSpace(shape, complexity=2, zerocenter=False)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
).transform(d2o_data),
np.array([[0.50541615 + 0.50558267j, -0.01458536 - 0.01646137j,
0.01649006 + 0.01990988j, 0.04668049 - 0.03351745j,
-0.04382765 - 0.06455639j, -0.05978564 + 0.01334044j],
[-0.05347464 + 0.04233343j, -0.05167177 + 0.00643947j,
-0.01995970 - 0.01168872j, 0.10653817 + 0.03885947j,
-0.03298075 - 0.00374715j, 0.00622585 - 0.01037453j],
[-0.01128964 - 0.02424692j, -0.03347793 - 0.0358814j,
-0.03924164 - 0.01978305j, 0.03821242 - 0.00435542j,
0.07533170 + 0.14590143j, -0.01493027 - 0.02664675j],
[0.02238926 + 0.06140625j, -0.06211313 + 0.03317753j,
0.01519073 + 0.02842563j, 0.00517758 + 0.08601604j,
-0.02246912 - 0.01942764j, -0.06627311 - 0.08763801j],
[-0.02492378 - 0.06097411j, 0.06365649 - 0.09346585j,
0.05031486 + 0.00858656j, -0.00881969 + 0.01842357j,
-0.01972641 - 0.00994365j, 0.05289453 - 0.06822038j],
[-0.01865586 - 0.08640926j, 0.03414096 - 0.02605602j,
-0.09492552 + 0.01306734j, 0.09355730 + 0.07553701j,
-0.02395259 - 0.02185743j, -0.03107832 - 0.04714527j]])
)
gl_hp_lm_test_spaces,
testcase_func_name=custom_name_func
)
def test_check_codomain_mismatch(self, space):
x = GLSpace(8)
with assert_raises(TypeError):
transformator.create(x, space)
class TestHPLMTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = HPSpace(8)
with assert_raises(ValueError):
transformator.create(x, None)
@parameterized.expand(
gl_hp_lm_test_spaces,
testcase_func_name=custom_name_func
)
def test_check_codomain_mismatch(self, space):
x = GLSpace(8)
with assert_raises(TypeError):
transformator.create(x, space)
class TestLMTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = LMSpace(8)
with assert_raises(ValueError):
transformator.create(x, None)
@parameterized.expand(
lm_gl_hp_test_spaces,
testcase_func_name=custom_name_func
)
def test_check_codomain_mismatch(self, space):
x = LMSpace(8)
with assert_raises(ValueError):
transformator.create(x, space)
if __name__ == '__main__':
unittest.main()
import numpy as np
from fftw import FFTW
from gfft import GFFT
from gltransform import GLTransform
......@@ -6,8 +7,6 @@ from lmtransform import LMTransform
from nifty.config import about, dependency_injector as gdi
from nifty import RGSpace, GLSpace, HPSpace, LMSpace
import numpy as np
class Transformation(object):
"""
......@@ -15,113 +14,6 @@ class Transformation(object):
method for all transforms.
"""
@staticmethod
def check_codomain(domain, codomain):
if codomain is None:
return False
if isinstance(domain, RGSpace):
if not isinstance(codomain, RGSpace):
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a RGSpace."
))
if not np.all(np.array(domain.paradict['shape']) ==
np.array(codomain.paradict['shape'])):
return False
if domain.harmonic == codomain.harmonic:
return False
# Check complexity
# Prepare shorthands
dcomp = domain.paradict['complexity']
cocomp = codomain.paradict['complexity']
# Case 1: if domain is completely complex, the codomain
# must be complex too
if dcomp == 2:
if cocomp != 2:
return False
# Case 2: if domain is hermitian, the codomain can be
# real, a warning is raised otherwise
elif dcomp == 1:
if cocomp > 0:
about.warnings.cprint(
"WARNING: Unrecommended codomain! " +
"The domain is hermitian, hence the" +
"codomain should be restricted to real values."
)
# Case 3: if domain is real, the codomain should be hermitian
elif dcomp == 0:
if cocomp == 2:
about.warnings.cprint(
"WARNING: Unrecommended codomain! " +
"The domain is real, hence the" +
"codomain should be restricted to" +
"hermitian configuration."
)
elif cocomp == 0:
return False
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(
np.absolute(np.array(domain.paradict['shape']) *
np.array(domain.distances) *
np.array(
codomain.distances) - 1) < domain.epsilon):
return False
elif isinstance(domain, GLSpace):
if not isinstance(codomain, LMSpace):
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."
))
# shorthand
nlat = domain.paradict['nlat']
nlon = domain.paradict['nlon']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False
elif isinstance(domain, HPSpace):
if not isinstance(codomain, LMSpace):
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."
))
nside = domain.paradict['nside']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (3 * nside - 1 != lmax) or (lmax != mmax):
return False
elif isinstance(domain, LMSpace):
if isinstance(codomain, GLSpace):
nlat = codomain.paradict['nlat']
nlon = codomain.paradict['nlon']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (nlat != lmax + 1) or \
(nlon != 2 * lmax + 1):
return False
elif isinstance(codomain, HPSpace):
nside = codomain.paradict['nside']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (3 * nside - 1 != lmax):
return False
else:
raise ValueError('ERROT: LMSpace codomain should be ' +
'HPSpace or GLSpace')
else:
return False
return True
def __init__(self, domain, codomain, module=None):
pass
......@@ -131,7 +23,7 @@ class Transformation(object):
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if Transformation.check_codomain(domain, codomain):
if self.check_codomain(domain, codomain):
if module is None:
if gdi.get('pyfftw') is None:
if gdi.get('gfft') is None:
......@@ -162,6 +54,64 @@ class RGRGTransformation(Transformation):
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, RGSpace):
raise TypeError('ERROR: domain must be a RGSpace')
if codomain is None:
return False
if not isinstance(codomain, RGSpace):
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a RGSpace."
))
if not np.all(np.array(domain.paradict['shape']) ==
np.array(codomain.paradict['shape'])):
return False
if domain.harmonic == codomain.harmonic:
return False
# check complexity
dcomp = domain.paradict['complexity']
cocomp = codomain.paradict['complexity']
# Case 1: if domain is completely complex, the codomain
# must be complex too
if dcomp == 2:
if cocomp != 2:
return False
# Case 2: if domain is hermitian, the codomain can be
# real, a warning is raised otherwise
elif dcomp == 1:
if cocomp > 0:
about.warnings.cprint(
"WARNING: Unrecommended codomain! " +
"The domain is hermitian, hence the" +
"codomain should be restricted to real values."
)
# Case 3: if domain is real, the codomain should be hermitian
elif dcomp == 0:
if cocomp == 2:
about.warnings.cprint(
"WARNING: Unrecommended codomain! " +
"The domain is real, hence the" +
"codomain should be restricted to" +
"hermitian configuration."
)
elif cocomp == 0:
return False
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(np.absolute(np.array(domain.paradict['shape']) *
np.array(domain.distances) *
np.array(codomain.distances) - 1) < domain.epsilon):
return False
return True
def transform(self, val, axes=None, **kwargs):
if self._transform.codomain.harmonic:
# correct for forward fft
......@@ -178,29 +128,71 @@ class RGRGTransformation(Transformation):
class GLLMTransformation(Transformation):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
def __init__(self, domain, codomain, module=None):
if self.check_codomain(domain, codomain):
self._transform = GLTransform(domain, codomain)
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain is not a GLSpace')
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nlat = domain.paradict['nlat']
nlon = domain.paradict['nlon']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
class HPLMTransformation(Transformation):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
def __init__(self, domain, codomain, module=None):
if self.check_codomain(domain, codomain):
self._transform = HPTransform(domain, codomain)
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, HPSpace):
raise TypeError('ERROR: domain is not a HPSpace')
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nside = domain.paradict['nside']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (3 * nside - 1 != lmax) or (lmax != mmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
class LMGLTransformation(Transformation):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
def __init__(self, domain, codomain, module=None):
if self.check_codomain(domain, codomain):
if gdi.get('libsharp_wrapper_gl') is None:
raise ImportError("The module libsharp is " +
"needed but not available.")
......@@ -210,12 +202,34 @@ class LMGLTransformation(Transformation):
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, GLSpace):
raise TypeError('ERROR: codomain must be a GLSpace.')
nlat = codomain.paradict['nlat']
nlon = codomain.paradict['nlon']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1):
return False
return True
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
class LMHPTransformation(Transformation):
def __init__(self, domain, codomain):
if Transformation.check_codomain(domain, codomain):
def __init__(self, domain, codomain, module=None):
if self.check_codomain(domain, codomain):
if gdi.get('healpy') is None:
raise ImportError("The module healpy is needed" +
"but not available.")
......@@ -225,5 +239,24 @@ class LMHPTransformation(Transformation):
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain is not a LMSpace')
if codomain is None:
return False
if not isinstance(codomain, HPSpace):
raise TypeError('ERROR: codomain must be a HPSpace.')
nside = codomain.paradict['nside']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
if (lmax != mmax) or (3 * nside - 1 != lmax):
return False
return True
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
\ No newline at end of file
return self._transform.transform(val, axes, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment