Commit 08bc6e76 authored by Jait Dixit's avatar Jait Dixit
Browse files

Add HPLMTransformation and GLLMTransformation

parent 726e82bb
......@@ -20,21 +20,19 @@ def custom_name_func(testcase_func, param_num, param):
parameterized.to_safe_name("_".join(str(x) for x in param.args)),
)
def weighted_np_transform(val, domain, codomain, axes=None):
if codomain.harmonic:
# correct for forward fft
val = domain.calc_weight(val, power=1)
def check_equality(space, data1, data2):
return space.unary_operation(space.binary_operation(data1, data2, 'eq'),
'all')
# Perform the transformation
Tval = np.fft.fftn(val, axes=axes)
if not codomain.harmonic:
# correct for inverse fft
Tval = codomain.calc_weight(Tval, power=-1)
def check_almost_equality(space, data1, data2, integers=7):
return space.unary_operation(
space.binary_operation(
space.unary_operation(
space.binary_operation(data1, data2, 'sub'),
'abs'),
10. ** (-1. * integers), 'le'),
'all')
return Tval
###############################################################################
......@@ -112,7 +110,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
transformator.create(
x, x.get_codomain(), module=module
).transform(a),
np.fft.fftn(a)
weighted_np_transform(a, x, x.get_codomain())
)
@parameterized.expand(
......@@ -127,7 +125,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
transformator.create(
x, x.get_codomain(), module=module
).transform(b, axes=(1,)),
np.fft.fftn(a, axes=(1,))
weighted_np_transform(a, x, x.get_codomain(), axes=(1,))
)
@parameterized.expand(
......@@ -142,7 +140,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
transformator.create(
x, x.get_codomain(), module=module
).transform(b),
np.fft.fftn(a)
weighted_np_transform(a, x, x.get_codomain())
)
@parameterized.expand(
......@@ -157,7 +155,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
transformator.create(
x, x.get_codomain(), module='pyfftw'
).transform(b),
np.fft.fftn(a)
weighted_np_transform(a, x, x.get_codomain())
)
@parameterized.expand(
......@@ -172,9 +170,41 @@ class TestRGSpaceTransforms(unittest.TestCase):
transformator.create(
x, x.get_codomain(), module='pyfftw'
).transform(b),
np.fft.fftn(a)
weighted_np_transform(a, x, x.get_codomain())
)
@parameterized.expand(
itertools.product(rg_fft_modules),
testcase_func_name=custom_name_func)
def test_calc_transform_explicit(self, module):
data = rg_test_data.copy()
d2o_data = d2o.distributed_data_object(data)
shape = data.shape
x = RGSpace(shape, complexity=2, zerocenter=False)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
).transform(d2o_data),
np.array([[0.50541615 + 0.50558267j, -0.01458536 - 0.01646137j,
0.01649006 + 0.01990988j, 0.04668049 - 0.03351745j,
-0.04382765 - 0.06455639j, -0.05978564 + 0.01334044j],
[-0.05347464 + 0.04233343j, -0.05167177 + 0.00643947j,
-0.01995970 - 0.01168872j, 0.10653817 + 0.03885947j,
-0.03298075 - 0.00374715j, 0.00622585 - 0.01037453j],
[-0.01128964 - 0.02424692j, -0.03347793 - 0.0358814j,
-0.03924164 - 0.01978305j, 0.03821242 - 0.00435542j,
0.07533170 + 0.14590143j, -0.01493027 - 0.02664675j],
[0.02238926 + 0.06140625j, -0.06211313 + 0.03317753j,
0.01519073 + 0.02842563j, 0.00517758 + 0.08601604j,
-0.02246912 - 0.01942764j, -0.06627311 - 0.08763801j],
[-0.02492378 - 0.06097411j, 0.06365649 - 0.09346585j,
0.05031486 + 0.00858656j, -0.00881969 + 0.01842357j,
-0.01972641 - 0.00994365j, 0.05289453 - 0.06822038j],
[-0.01865586 - 0.08640926j, 0.03414096 - 0.02605602j,
-0.09492552 + 0.01306734j, 0.09355730 + 0.07553701j,
-0.02395259 - 0.02185743j, -0.03107832 - 0.04714527j]])
)
if __name__ == '__main__':
unittest.main()
......@@ -22,7 +22,7 @@ class GLTransform(Transform):
def transform(self, val, axes, **kwargs):
if self.domain.discrete:
val = self.calc_weight(val, power=-0.5)
val = self.domain.calc_weight(val, power=-0.5)
# shorthands for transform parameters
nlat = self.domain.paradict['nlat']
......
......@@ -24,7 +24,7 @@ class HPTransform(Transform):
niter = kwargs['niter'] if 'niter' in kwargs else 0
if self.domain.discrete:
val = self.calc_weight(val, power=-0.5)
val = self.domain.calc_weight(val, power=-0.5)
# shorthands for transform parameters
lmax = self.codomain.paradict['lmax']
......
from fftw import FFTW
from gfft import GFFT
from gfft import GFFT
from gltransform import GLTransform
from hptransform import HPTransform
from nifty.config import about, dependency_injector as gdi
from nifty import RGSpace
from nifty import RGSpace, GLSpace, HPSpace, LMSpace
import numpy as np
......@@ -21,7 +22,7 @@ class Transformation(object):
if isinstance(domain, RGSpace):
if not isinstance(codomain, RGSpace):
raise TypeError(about._errors.cstring(
"ERROR: The given codomain must be a rg_space."
"ERROR: codomain must be a RGSpace."
))
if not np.all(np.array(domain.paradict['shape']) ==
......@@ -66,7 +67,34 @@ class Transformation(object):
if not np.all(
np.absolute(np.array(domain.paradict['shape']) *
np.array(domain.distances) *
np.array(codomain.distances) - 1) < domain.epsilon):
np.array(
codomain.distances) - 1) < domain.epsilon):
return False
elif isinstance(domain, GLSpace):
if not isinstance(codomain, LMSpace):
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."
))
# shorthand
nlat = domain.paradict['nlat']
nlon = domain.paradict['nlon']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False
elif isinstance(domain, HPSpace):
if not isinstance(codomain, LMSpace):
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."
))
nside = domain.paradict['nside']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
if (3 * nside - 1 != lmax) or (lmax != mmax):
return False
else:
return False
......@@ -86,10 +114,10 @@ class RGRGTransformation(Transformation):
if module is None:
if gdi.get('pyfftw') is None:
if gdi.get('gfft') is None:
self._transform =\
self._transform = \
GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
self._transform =\
self._transform = \
GFFT(domain, codomain, gdi.get('gfft'))
self._transform = FFTW(domain, codomain)
else:
......@@ -100,12 +128,12 @@ class RGRGTransformation(Transformation):
raise RuntimeError("ERROR: pyfftw is not available.")
elif module == 'gfft':
if gdi.get('gfft') is not None:
self._transform =\
self._transform = \
GFFT(domain, codomain, gdi.get('gfft'))
else:
raise RuntimeError("ERROR: gfft is not available.")
elif module == 'gfft_dummy':
self._transform =\
self._transform = \
GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('Given FFT module is not known: ' +
......@@ -113,5 +141,38 @@ class RGRGTransformation(Transformation):
else:
raise ValueError("ERROR: Incompatible codomain!")
def transform(self, val, axes=None, **kwargs):
if self._transform.codomain.harmonic:
# correct for forward fft
val = self._transform.domain.calc_weight(val, power=1)
# Perform the transformation
Tval = self._transform.transform(val, axes, **kwargs)
if not self._transform.codomain.harmonic:
# correct for inverse fft
Tval = self._transform.codomain.calc_weight(Tval, power=-1)
return Tval
class GLLMTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if Transformation.check_codomain(domain, codomain):
self._transform = GLTransform(domain, codomain)
else:
raise ValueError("ERROR: Incompatible codomain!")
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
class HPLMTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if Transformation.check_codomain(domain, codomain):
self._transform = HPTransform(domain, codomain)
else:
raise ValueError("ERROR: Incompatible codomain!")
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
\ No newline at end of file
......@@ -3,7 +3,7 @@ import numpy as np
from nifty.rg import RGSpace
from nifty.lm import GLSpace, HPSpace, LMSpace
from transformation import RGRGTransformation
import transformation
class TransformationFactory(object):
......@@ -17,7 +17,13 @@ class TransformationFactory(object):
def _get_transform(self, domain, codomain, module):
if isinstance(domain, RGSpace):
return RGRGTransformation(domain, codomain, module)
return transformation.RGRGTransformation(domain, codomain, module)
elif isinstance(domain, GLSpace):
return transformation.GLLMTransformation(domain, codomain, module)
elif isinstance(domain, HPSpace):
return transformation.HPLMTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: unknown domain')
def create(self, domain, codomain, module=None):
key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment