Commit 7fc8256f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent 2511a1d0
......@@ -17,16 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from .... import GLSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
import pyHealpix
class GLLMTransformation(SlicingTransformation):
def __init__(self, domain, codomain=None):
super(GLLMTransformation, self).__init__(domain, codomain)
......@@ -35,20 +31,18 @@ class GLLMTransformation(SlicingTransformation):
return False
def _transformation_of_slice(self, inp):
nlat = self.domain.nlat
nlon = self.domain.nlon
lmax = self.codomain.lmax
mmax = self.codomain.mmax
sjob = pyHealpix.sharpjob_d()
sjob.set_Gauss_geometry(nlat, nlon)
sjob.set_Gauss_geometry(self.domain.nlat, self.domain.nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
return \
lm_transformation_helper.buildIdx(sjob.map2alm(inp.real),
lmax=lmax)\
+1j*lm_transformation_helper.buildIdx(sjob.map2alm(inp.imag),
lmax=lmax)
rr = sjob.map2alm(inp.real)
rr = lm_transformation_helper.buildIdx(rr, lmax=lmax)
ri = sjob.map2alm(inp.imag)
ri = lm_transformation_helper.buildIdx(ri, lmax=lmax)
return rr + 1j*ri
else:
return lm_transformation_helper.buildIdx(sjob.map2alm(inp),
lmax=lmax)
rr = sjob.map2alm(inp)
return lm_transformation_helper.buildIdx(rr, lmax=lmax)
......@@ -17,17 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from .... import HPSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
import pyHealpix
class HPLMTransformation(SlicingTransformation):
def __init__(self, domain, codomain=None):
super(HPLMTransformation, self).__init__(domain, codomain)
......@@ -40,18 +35,12 @@ class HPLMTransformation(SlicingTransformation):
mmax = lmax
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal,
resultImag] = [pyHealpix.map2alm(x, lmax, mmax)
for x in (inp.real, inp.imag)]
[resultReal,
resultImag] = [lm_transformation_helper.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
result = resultReal +1j*resultImag
rr = pyHealpix.map2alm(inp.real, lmax, mmax)
rr = lm_transformation_helper.buildIdx(rr, lmax=lmax)
ri = pyHealpix.map2alm(inp.imag, lmax, mmax)
ri = lm_transformation_helper.buildIdx(ri, lmax=lmax)
return rr + 1j*ri
else:
result = pyHealpix.map2alm(inp, lmax, mmax)
result = lm_transformation_helper.buildIdx(result, lmax=lmax)
return result
rr = pyHealpix.map2alm(inp, lmax, mmax)
return lm_transformation_helper.buildIdx(rr, lmax=lmax)
......@@ -17,48 +17,31 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from .... import GLSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
import pyHealpix
class LMGLTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None):
super(LMGLTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
def _transformation_of_slice(self, inp):
nlat = self.codomain.nlat
nlon = self.codomain.nlon
lmax = self.domain.lmax
mmax = self.domain.mmax
sjob = pyHealpix.sharpjob_d()
sjob.set_Gauss_geometry(nlat, nlon)
sjob.set_Gauss_geometry(self.codomain.nlat, self.codomain.nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal,
resultImag] = [lm_transformation_helper.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [sjob.alm2map(x)
for x in [resultReal, resultImag]]
result = resultReal + 1j*resultImag
rr = lm_transformation_helper.buildLm(inp.real, lmax=lmax)
ri = lm_transformation_helper.buildLm(inp.imag, lmax=lmax)
return sjob.alm2map(rr) + 1j*sjob.alm2map(ri)
else:
result = lm_transformation_helper.buildLm(inp, lmax=lmax)
result = sjob.alm2map(result)
return result
return sjob.alm2map(result)
......@@ -17,10 +17,8 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from .... import HPSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
import pyHealpix
......@@ -39,17 +37,12 @@ class LMHPTransformation(SlicingTransformation):
mmax = lmax
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal,
resultImag] = [lm_transformation_helper.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [pyHealpix.alm2map(x, lmax, mmax, nside)
for x in [resultReal, resultImag]]
result = resultReal + 1j*resultImag
rr = lm_transformation_helper.buildLm(inp.real, lmax=lmax)
ri = lm_transformation_helper.buildLm(inp.imag, lmax=lmax)
rr = pyHealpix.alm2map(rr, lmax, mmax, nside)
ri = pyHealpix.alm2map(ri, lmax, mmax, nside)
return rr + 1j*ri
else:
result = lm_transformation_helper.buildLm(inp, lmax=lmax)
result = pyHealpix.alm2map(result, lmax, mmax, nside)
return result
rr = lm_transformation_helper.buildLm(inp, lmax=lmax)
return pyHealpix.alm2map(rr, lmax, mmax, nside)
......@@ -16,54 +16,18 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import range
from builtins import object
import warnings
import numpy as np
from .... import nifty_utilities as utilities
from functools import reduce
from builtins import object, range
import pyfftw
class Transform(object):
class SerialFFT(object):
"""
A generic fft object without any implementation.
The pyfftw pendant of a fft object.
"""
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
def transform(self, val, axes):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
class SerialFFT(Transform):
"""
The numpy fft pendant of a fft object.
"""
def __init__(self, domain, codomain):
super(SerialFFT, self).__init__(domain, codomain)
pyfftw.interfaces.cache.enable()
def transform(self, val, axes):
......@@ -72,7 +36,7 @@ class SerialFFT(Transform):
Parameters
----------
val : or numpy.ndarray
val : numpy.ndarray
The value-array of the field which is supposed to
be transformed.
......@@ -90,13 +54,7 @@ class SerialFFT(Transform):
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
return self._atomic_transform(local_val=val,
axes=axes,
local_offset_Q=False)
def _atomic_transform(self, local_val, axes, local_offset_Q):
# perform the transformation
if self.codomain.harmonic:
return pyfftw.interfaces.numpy_fft.fftn(local_val, axes=axes)
return pyfftw.interfaces.numpy_fft.fftn(val, axes=axes)
else:
return pyfftw.interfaces.numpy_fft.ifftn(local_val, axes=axes)
return pyfftw.interfaces.numpy_fft.ifftn(val, axes=axes)
......@@ -29,12 +29,9 @@ class SlicingTransformation(Transformation):
return_shape = np.array(val.shape)
return_shape[list(axes)] = self.codomain.shape
return_shape = tuple(return_shape)
return_val = None
return_val = np.empty(return_shape,dtype=val.dtype)
for slice_list in utilities.get_slice_list(val.shape, axes):
if return_val is None:
return_val = np.empty(return_shape,dtype=val.dtype)
return_val[slice_list] = self._transformation_of_slice(
val[slice_list])
return return_val
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment