Commit 447d785c authored by Jait Dixit's avatar Jait Dixit
Browse files

Fix issue #54

parent e57fd73f
......@@ -577,7 +577,7 @@ class operator(object):
if domain is None:
domain = diag.domain
# weight if ...
if and bare:
if bare:
if(isinstance(diag, tuple)): # diag == (diag,variance)
return (diag[0].weight(power=-1),
diag[1].weight(power=-1))
......
......@@ -50,12 +50,6 @@ lm_gl_hp_test_spaces = [(LMSpace(8),), (RGSpace(8),)]
###############################################################################
class TestRGRGTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = RGSpace((8, 8))
with assert_raises(ValueError):
transformator.create(x, None)
@parameterized.expand(
rg_rg_test_spaces,
testcase_func_name=custom_name_func
......@@ -81,7 +75,7 @@ class TestRGRGTransformation(unittest.TestCase):
b = d2o.distributed_data_object(np.ones((8, 8)))
with assert_raises(ValueError):
transformator.create(
x, x.get_codomain(), module=module
x, module=module
).transform(b, axes=(0, 1, 2))
@parameterized.expand(
......@@ -93,7 +87,7 @@ class TestRGRGTransformation(unittest.TestCase):
a = np.ones(shape)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
x, module=module
).transform(a),
weighted_np_transform(a, x, x.get_codomain())
)
......@@ -108,7 +102,7 @@ class TestRGRGTransformation(unittest.TestCase):
b = d2o.distributed_data_object(a)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
x, module=module
).transform(b, axes=(1,)),
weighted_np_transform(a, x, x.get_codomain(), axes=(1,))
)
......@@ -123,7 +117,7 @@ class TestRGRGTransformation(unittest.TestCase):
b = d2o.distributed_data_object(a, distribution_strategy='not')
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
x, module=module
).transform(b),
weighted_np_transform(a, x, x.get_codomain())
)
......@@ -138,7 +132,7 @@ class TestRGRGTransformation(unittest.TestCase):
b = d2o.distributed_data_object(a)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module='pyfftw'
x, module='pyfftw'
).transform(b),
weighted_np_transform(a, x, x.get_codomain())
)
......@@ -153,18 +147,12 @@ class TestRGRGTransformation(unittest.TestCase):
b = d2o.distributed_data_object(a, distribution_strategy='equal')
assert np.allclose(
transformator.create(
x, x.get_codomain(), module='pyfftw'
x, module='pyfftw'
).transform(b),
weighted_np_transform(a, x, x.get_codomain())
)
class TestGLLMTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = GLSpace(8)
with assert_raises(ValueError):
transformator.create(x, None)
@parameterized.expand(
gl_hp_lm_test_spaces,
testcase_func_name=custom_name_func
......@@ -175,11 +163,6 @@ class TestGLLMTransformation(unittest.TestCase):
transformator.create(x, space)
class TestHPLMTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = HPSpace(8)
with assert_raises(ValueError):
transformator.create(x, None)
@parameterized.expand(
gl_hp_lm_test_spaces,
......@@ -191,12 +174,6 @@ class TestHPLMTransformation(unittest.TestCase):
transformator.create(x, space)
class TestLMTransformation(unittest.TestCase):
# all domain/codomain checks
def test_check_codomain_none(self):
x = LMSpace(8)
with assert_raises(ValueError):
transformator.create(x, None)
@parameterized.expand(
lm_gl_hp_test_spaces,
testcase_func_name=custom_name_func
......
......@@ -13,7 +13,10 @@ class GLLMTransformation(Transformation):
if 'libsharp_wrapper_gl' not in gdi:
raise ImportError("The module libsharp is needed but not available")
if self.check_codomain(domain, codomain):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
......
......@@ -13,7 +13,10 @@ class HPLMTransformation(Transformation):
if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available")
if self.check_codomain(domain, codomain):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
......
......@@ -14,7 +14,10 @@ class LMGLTransformation(Transformation):
raise ImportError(
"The module libsharp is needed but not available.")
if self.check_codomain(domain, codomain):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
......
......@@ -7,30 +7,35 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None):
if self.check_codomain(domain, codomain):
if module is None:
if nifty_configuration['fft_module'] == 'pyfftw':
self._transform = FFTW(domain, codomain)
elif nifty_configuration['fft_module'] == 'gfft' or \
nifty_configuration['fft_module'] == 'gfft_dummy':
self._transform = \
GFFT(domain,
codomain,
gdi.get(nifty_configuration['fft_module']))
else:
raise ValueError('ERROR: unknow default FFT module:' +
nifty_configuration['fft_module'])
if codomain is None:
codomain = self.get_codomain(domain)
else:
if not self.check_codomain(domain, codomain):
raise ValueError("ERROR: incompatible codomain!")
if module is None:
if nifty_configuration['fft_module'] == 'pyfftw':
self._transform = FFTW(domain, codomain)
elif nifty_configuration['fft_module'] == 'gfft' or \
nifty_configuration['fft_module'] == 'gfft_dummy':
self._transform = \
GFFT(domain,
codomain,
gdi.get(nifty_configuration['fft_module']))
else:
if module == 'pyfftw':
self._transform = FFTW(domain, codomain)
elif module == 'gfft':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft'))
elif module == 'gfft_dummy':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft_dummy'))
raise ValueError('ERROR: unknow default FFT module:' +
nifty_configuration['fft_module'])
else:
raise ValueError("ERROR: incompatible codomain!")
if module == 'pyfftw':
self._transform = FFTW(domain, codomain)
elif module == 'gfft':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft'))
elif module == 'gfft_dummy':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('ERROR: unknow FFT module:' + module)
@staticmethod
def get_codomain(domain, cozerocenter=None, **kwargs):
......
......@@ -34,7 +34,7 @@ class TransformationFactory(object):
else:
raise ValueError('ERROR: unknown domain')
def create(self, domain, codomain, module=None):
def create(self, domain, codomain=None, module=None):
key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
(179 * module.__hash__()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment