Commit 726e82bb authored by Jait Dixit's avatar Jait Dixit
Browse files

Add HLTransform and GLTransform to available transformations

parent 2057a238
......@@ -21,6 +21,21 @@ def custom_name_func(testcase_func, param_num, param):
)
def check_equality(space, data1, data2):
return space.unary_operation(space.binary_operation(data1, data2, 'eq'),
'all')
def check_almost_equality(space, data1, data2, integers=7):
return space.unary_operation(
space.binary_operation(
space.unary_operation(
space.binary_operation(data1, data2, 'sub'),
'abs'),
10. ** (-1. * integers), 'le'),
'all')
###############################################################################
rg_fft_modules = []
......@@ -28,7 +43,6 @@ for name in ['gfft', 'gfft_dummy', 'pyfftw']:
if RG_GC.validQ('fft_module', name):
rg_fft_modules += [name]
rg_test_shapes = [(128, 128), (179, 179), (512, 512)]
rg_test_data = np.array(
......@@ -51,6 +65,7 @@ rg_test_data = np.array(
0.79611264 + 0.50926774j, 0.35372794 + 0.10468412j,
0.46140736 + 0.09449825j, 0.82044644 + 0.95992843j]])
###############################################################################
class TestRGSpaceTransforms(unittest.TestCase):
......@@ -74,8 +89,8 @@ class TestRGSpaceTransforms(unittest.TestCase):
def test_check_codomain_rgspecific(self, complexity, distances, harmonic):
x = RGSpace((8, 8), complexity=complexity,
distances=distances, harmonic=harmonic)
assert(Transformation.check_codomain(x, x.get_codomain()))
assert(Transformation.check_codomain(x, x.get_codomain()))
assert (Transformation.check_codomain(x, x.get_codomain()))
assert (Transformation.check_codomain(x, x.get_codomain()))
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_shapemismatch(self, module):
......@@ -160,5 +175,6 @@ class TestRGSpaceTransforms(unittest.TestCase):
np.fft.fftn(a)
)
if __name__ == '__main__':
unittest.main()
......@@ -24,8 +24,6 @@ class FFTW(Transform):
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
self.name = 'pyfftw'
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
......
import numpy as np
from transform import Transform
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
gl = gdi.get('libsharp_wrapper_gl')
class GLTransform(Transform):
"""
GLTransform wrapper for libsharp's transform functions
"""
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
if 'libsharp_wrapper_gl' not in gdi:
raise ImportError("The module libsharp_wrapper_gl " +
"is needed but not available")
def transform(self, val, axes, **kwargs):
if self.domain.discrete:
val = self.calc_weight(val, power=-0.5)
# shorthands for transform parameters
nlat = self.domain.paradict['nlat']
nlon = self.domain.paradict['nlon']
lmax = self.codomain.paradict['lmax']
mmax = self.codomain.paradict['mmax']
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
return_val = None
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
if self.domain.dtype == np.dtype('float32'):
inp = gl.map2alm_f(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
else:
inp = gl.map2alm(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
\ No newline at end of file
import numpy as np
from transform import Transform
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
hp = gdi.get('healpy')
class HPTransform(Transform):
"""
GLTransform wrapper for libsharp's transform functions
"""
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
if 'healpy' not in gdi:
raise ImportError("The module healpy is needed but not available")
def transform(self, val, axes, **kwargs):
# get by number of iterations from kwargs
niter = kwargs['niter'] if 'niter' in kwargs else 0
if self.domain.discrete:
val = self.calc_weight(val, power=-0.5)
# shorthands for transform parameters
lmax = self.codomain.paradict['lmax']
mmax = self.codomain.paradict['mmax']
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
return_val = None
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
inp = hp.map2alm(inp.astype(np.float64, copy=False),
lmax=lmax, mmax=mmax, iter=niter, pol=True,
use_weights=False, datapath=None)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment