Commit a2690f08 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge nifty2go

parents 07686dcb 2d59db54
Pipeline #21100 passed with stage
in 4 minutes and 55 seconds
......@@ -16,15 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import Field, DomainTuple, nifty_utilities as utilities
from ..spaces import RGSpace, GLSpace, HPSpace, LMSpace
import numpy as np
from .. import DomainTuple
from ..spaces import RGSpace
from .linear_operator import LinearOperator
from .fft_operator_support import RGRGTransformation,\
LMGLTransformation,\
LMHPTransformation,\
GLLMTransformation,\
HPLMTransformation
from .fft_operator_support import RGRGTransformation, SphericalTransformation
class FFTOperator(LinearOperator):
......@@ -76,21 +72,6 @@ class FFTOperator(LinearOperator):
if "domain" or "target" are not of the proper type.
"""
# ---Class attributes---
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
GLSpace: LMSpace,
LMSpace: GLSpace,
}
transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
(HPSpace, LMSpace): HPLMTransformation,
(GLSpace, LMSpace): GLLMTransformation,
(LMSpace, HPSpace): LMHPTransformation,
(LMSpace, GLSpace): LMGLTransformation
}
def __init__(self, domain, target=None, space=None):
super(FFTOperator, self).__init__()
......@@ -101,43 +82,42 @@ class FFTOperator(LinearOperator):
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if (space<0) or space>=len(self._domain.domains):
if space < 0 or space >= len(self._domain.domains):
raise ValueError("space index out of range")
self._space = space
adom = self.domain[self._space]
if target is None:
target = [ dom for dom in self.domain ]
target = [dom for dom in self.domain]
target[self._space] = adom.get_default_codomain()
self._target = DomainTuple.make(target)
atgt = self._target[self._space]
atgt = self._target[space]
adom.check_codomain(atgt)
atgt.check_codomain(adom)
# Create transformation instances
forward_class = self.transformation_dictionary[
(adom.__class__, atgt.__class__)]
backward_class = self.transformation_dictionary[
(atgt.__class__, adom.__class__)]
self._forward_transformation = forward_class(adom, atgt)
self._backward_transformation = backward_class(atgt, adom)
def _times_helper(self, x, other, trafo):
axes = x.domain.axes[self._space]
new_val, fct = trafo.transform(x.val, axes=axes)
res = Field(other, new_val, copy=False)
if fct != 1.:
res *= fct
if self._target[space].harmonic:
pdom, hdom = (self._domain, self._target)
else:
pdom, hdom = (self._target, self._domain)
if isinstance(pdom[space], RGSpace):
self._trafo = RGRGTransformation(pdom, hdom, space)
else:
self._trafo = SphericalTransformation(pdom, hdom, space)
def _times_helper(self, x):
if issubclass(x.dtype.type, np.complexfloating):
res = (self._trafo.transform(x.real) +
1j * self._trafo.transform(x.imag))
else:
res = self._trafo.transform(x)
return res
def _times(self, x):
return self._times_helper(x, self.target, self._forward_transformation)
return self._times_helper(x)
def _adjoint_times(self, x):
return self._times_helper(x, self.domain, self._backward_transformation)
return self._times_helper(x)
@property
def domain(self):
......@@ -149,5 +129,4 @@ class FFTOperator(LinearOperator):
@property
def unitary(self):
return (self._forward_transformation.unitary and
self._backward_transformation.unitary)
return self._trafo.unitary
......@@ -21,76 +21,72 @@ import numpy as np
from .. import nifty_utilities as utilities
from ..low_level_library import hartley
from ..dobj import to_ndarray as to_np, from_ndarray as from_np
from ..field import Field
from ..spaces.gl_space import GLSpace
class Transformation(object):
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
def unitary(self):
raise NotImplementedError
def transform(self, val, axes=None):
raise NotImplementedError
def __init__(self, pdom, hdom, space):
self.pdom = pdom
self.hdom = hdom
self.space = space
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None):
def __init__(self, pdom, hdom, space):
import pyfftw
super(RGRGTransformation, self).__init__(domain, codomain)
super(RGRGTransformation, self).__init__(pdom, hdom, space)
pyfftw.interfaces.cache.enable()
# correct for forward/inverse fft.
# naively one would set power to 0.5 here in order to
# apply effectively a factor of 1/sqrt(N) to the field.
# BUT: the pixel volumes of the domain and codomain are different.
# Hence, in order to produce the same scalar product, power==1.
self.fct_p2h = pdom[space].scalar_dvol()
self.fct_h2p = 1./(pdom[space].scalar_dvol()*hdom[space].dim)
@property
def unitary(self):
return True
def transform(self, val, axes=None):
def transform(self, x):
"""
RG -> RG transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
x : Field
The field to be transformed
"""
# correct for forward/inverse fft.
# naively one would set power to 0.5 here in order to
# apply effectively a factor of 1/sqrt(N) to the field.
# BUT: the pixel volumes of the domain and codomain are different.
# Hence, in order to produce the same scalar product, power==1.
if self.codomain.harmonic:
fct = self.domain.scalar_dvol()
axes = x.domain.axes[self.space]
p2h = x.domain == self.pdom
if p2h:
Tval = Field(self.hdom, from_np(hartley(to_np(x.val), axes)))
else:
fct = 1./(self.codomain.scalar_dvol()*self.domain.dim)
Tval = Field(self.pdom, from_np(hartley(to_np(x.val), axes)))
fct = self.fct_p2h if p2h else self.fct_h2p
if fct != 1:
Tval *= fct
# Perform the transformation
if issubclass(val.dtype.type, np.complexfloating):
Tval = from_np(hartley(to_np(val.real), axes) +
1j*hartley(to_np(val.imag), axes))
else:
Tval = from_np(hartley(to_np(val), axes))
return Tval, fct
return Tval
class SlicingTransformation(Transformation):
def transform(self, val, axes=None):
val = to_np(val)
return_shape = np.array(val.shape)
return_shape[list(axes)] = self.codomain.shape
return_val = np.empty(tuple(return_shape), dtype=val.dtype)
for slice in utilities.get_slice_list(val.shape, axes):
return_val[slice] = self._transformation_of_slice(val[slice])
return from_np(return_val), 1.
def transform(self, x):
p2h = x.domain == self.pdom
if p2h:
res = Field(self.hdom, dtype=x.dtype)
for slice in utilities.get_slice_list(x.shape,
x.domain.axes[self.space]):
res.val[slice] = from_np(self._slice_p2h(to_np(x.val[slice])))
else:
res = Field(self.pdom, dtype=x.dtype)
def _transformation_of_slice(self, inp):
raise NotImplementedError
for slice in utilities.get_slice_list(x.shape,
x.domain.axes[self.space]):
res.val[slice] = from_np(self._slice_h2p(to_np(x.val[slice])))
return res
def buildLm(nr, lmax):
......@@ -108,94 +104,29 @@ def buildIdx(nr, lmax):
return res
class HPLMTransformation(SlicingTransformation):
@property
def unitary(self):
return False
def _transformation_of_slice(self, inp):
from pyHealpix import map2alm
lmax = self.codomain.lmax
mmax = lmax
if issubclass(inp.dtype.type, np.complexfloating):
rr = map2alm(inp.real, lmax, mmax)
rr = buildIdx(rr, lmax=lmax)
ri = map2alm(inp.imag, lmax, mmax)
ri = buildIdx(ri, lmax=lmax)
return rr + 1j*ri
else:
rr = map2alm(inp, lmax, mmax)
return buildIdx(rr, lmax=lmax)
class LMHPTransformation(SlicingTransformation):
@property
def unitary(self):
return False
def _transformation_of_slice(self, inp):
from pyHealpix import alm2map
nside = self.codomain.nside
lmax = self.domain.lmax
mmax = lmax
if issubclass(inp.dtype.type, np.complexfloating):
rr = buildLm(inp.real, lmax=lmax)
ri = buildLm(inp.imag, lmax=lmax)
rr = alm2map(rr, lmax, mmax, nside)
ri = alm2map(ri, lmax, mmax, nside)
return rr + 1j*ri
else:
rr = buildLm(inp, lmax=lmax)
return alm2map(rr, lmax, mmax, nside)
class GLLMTransformation(SlicingTransformation):
@property
def unitary(self):
return False
def _transformation_of_slice(self, inp):
class SphericalTransformation(SlicingTransformation):
def __init__(self, pdom, hdom, space):
super(SphericalTransformation, self).__init__(pdom, hdom, space)
from pyHealpix import sharpjob_d
lmax = self.codomain.lmax
mmax = self.codomain.mmax
sjob = sharpjob_d()
sjob.set_Gauss_geometry(self.domain.nlat, self.domain.nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
rr = sjob.map2alm(inp.real)
rr = buildIdx(rr, lmax=lmax)
ri = sjob.map2alm(inp.imag)
ri = buildIdx(ri, lmax=lmax)
return rr + 1j*ri
self.lmax = self.hdom[self.space].lmax
self.sjob = sharpjob_d()
self.sjob.set_triangular_alm_info(self.hdom[self.space].lmax,
self.hdom[self.space].mmax)
if isinstance(self.pdom[self.space], GLSpace):
self.sjob.set_Gauss_geometry(self.pdom[self.space].nlat,
self.pdom[self.space].nlon)
else:
rr = sjob.map2alm(inp)
return buildIdx(rr, lmax=lmax)
self.sjob.set_Healpix_geometry(self.pdom[self.space].nside)
class LMGLTransformation(SlicingTransformation):
@property
def unitary(self):
return False
def _transformation_of_slice(self, inp):
from pyHealpix import sharpjob_d
def _slice_p2h(self, inp):
rr = self.sjob.map2alm(inp)
return buildIdx(rr, lmax=self.lmax)
lmax = self.domain.lmax
mmax = self.domain.mmax
sjob = sharpjob_d()
sjob.set_Gauss_geometry(self.codomain.nlat, self.codomain.nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
rr = buildLm(inp.real, lmax=lmax)
ri = buildLm(inp.imag, lmax=lmax)
return sjob.alm2map(rr) + 1j*sjob.alm2map(ri)
else:
result = buildLm(inp, lmax=lmax)
return sjob.alm2map(result)
def _slice_h2p(self, inp):
result = buildLm(inp, lmax=self.lmax)
return self.sjob.alm2map(result)
......@@ -35,15 +35,12 @@ def _get_rtol(tp):
class FFTOperatorTests(unittest.TestCase):
@expand(product([16, ], [0.1, 1, 3.7],
[np.float64, np.float32, np.complex64, np.complex128],
['real', 'complex']))
def test_fft1D(self, dim1, d, itp, base):
[np.float64, np.float32, np.complex64, np.complex128]))
def test_fft1D(self, dim1, d, itp):
tol = _get_rtol(itp)
a = RGSpace(dim1, distances=d)
b = RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
fft = FFTOperator(domain=a, target=b)
fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
np.random.seed(16)
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
......@@ -53,17 +50,13 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([12, 15], [9, 12], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.float32, np.complex64, np.complex128],
['real', 'complex']))
def test_fft2D(self, dim1, dim2, d1, d2,
itp, base):
[np.float64, np.float32, np.complex64, np.complex128]))
def test_fft2D(self, dim1, dim2, d1, d2, itp):
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], distances=[d1, d2])
b = RGSpace([dim1, dim2],
distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
fft = FFTOperator(domain=a, target=b)
fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
inp = Field.from_random(domain=a, random_type='normal', std=7, mean=3,
dtype=itp)
......@@ -71,10 +64,8 @@ class FFTOperatorTests(unittest.TestCase):
assert_allclose(to_np(inp.val), to_np(out.val), rtol=tol, atol=tol)
@expand(product([0, 1, 2],
[np.float64, np.float32, np.complex64, np.complex128],
['real', 'complex']))
def test_composed_fft(self, index, dtype,
base):
[np.float64, np.float32, np.complex64, np.complex128]))
def test_composed_fft(self, index, dtype):
tol = _get_rtol(dtype)
a = [a1, a2, a3] = [RGSpace((32,)), RGSpace((4, 4)), RGSpace((5, 6))]
fft = FFTOperator(domain=a, space=index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment