Commit 72225dd1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent eaada6df
......@@ -27,13 +27,9 @@ import pyHealpix
class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None):
super(GLLMTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
......@@ -42,7 +38,7 @@ class GLLMTransformation(SlicingTransformation):
def get_codomain(cls, domain):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ an instance of the :py:class:`lm_space` class.
reasonable, i.e.\ an instance of the :py:class:`LMSpace` class.
Parameters
----------
......@@ -58,28 +54,17 @@ class GLLMTransformation(SlicingTransformation):
if not isinstance(domain, GLSpace):
raise TypeError("domain needs to be a GLSpace")
nlat = domain.nlat
lmax = nlat - 1
result = LMSpace(lmax=lmax)
return result
return LMSpace(lmax=domain.nlat-1,mmax=(domain.nlon-1)//2)
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, GLSpace):
raise TypeError("domain is not a GLSpace")
if not isinstance(codomain, LMSpace):
raise TypeError("codomain must be a LMSpace.")
nlat = domain.nlat
nlon = domain.nlon
lmax = codomain.lmax
mmax = codomain.mmax
super(GLLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
def _transformation_of_slice(self, inp):
nlat = self.domain.nlat
nlon = self.domain.nlon
lmax = self.codomain.lmax
......@@ -89,16 +74,11 @@ class GLLMTransformation(SlicingTransformation):
sjob.set_Gauss_geometry(nlat, nlon)
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal, resultImag] = [sjob.map2alm(x)
for x in (inp.real, inp.imag)]
[resultReal,
resultImag] = [lm_transformation_helper.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
return self._combine_complex_result(
lm_transformation_helper.buildIdx(sjob.map2alm(inp.real),
lmax=lmax),
lm_transformation_helper.buildIdx(sjob.map2alm(inp.imag),
lmax=lmax))
else:
result = sjob.map2alm(inp)
result = lm_transformation_helper.buildIdx(result, lmax=lmax)
return result
return lm_transformation_helper.buildIdx(sjob.map2alm(inp),
lmax=lmax)
......@@ -28,13 +28,9 @@ import pyHealpix
class HPLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None):
super(HPLMTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
......@@ -59,25 +55,17 @@ class HPLMTransformation(SlicingTransformation):
if not isinstance(domain, HPSpace):
raise TypeError("domain needs to be a HPSpace")
lmax = 2*domain.nside
result = LMSpace(lmax=lmax)
return result
return LMSpace(lmax=2*domain.nside)
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, HPSpace):
raise TypeError("domain is not a HPSpace")
if not isinstance(codomain, LMSpace):
raise TypeError("codomain must be a LMSpace.")
lmax = codomain.lmax
nside = domain.nside
super(HPLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
def _transformation_of_slice(self, inp):
lmax = self.codomain.lmax
mmax = lmax
......
......@@ -64,28 +64,18 @@ class LMGLTransformation(SlicingTransformation):
if not isinstance(domain, LMSpace):
raise TypeError("domain needs to be a LMSpace")
nlat = domain.lmax + 1
nlon = domain.lmax*2 + 1
result = GLSpace(nlat=nlat, nlon=nlon)
return result
return GLSpace(nlat=domain.lmax+1, nlon=domain.mmax*2+1)
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError("domain is not a LMSpace")
if not isinstance(codomain, GLSpace):
raise TypeError("codomain must be a GLSpace.")
nlat = codomain.nlat
nlon = codomain.nlon
lmax = domain.lmax
mmax = domain.mmax
super(LMGLTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
def _transformation_of_slice(self, inp):
nlat = self.codomain.nlat
nlon = self.codomain.nlon
lmax = self.domain.lmax
......
......@@ -26,13 +26,9 @@ import pyHealpix
class LMHPTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None):
super(LMHPTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
@property
def unitary(self):
return False
......@@ -70,16 +66,12 @@ class LMHPTransformation(SlicingTransformation):
def check_codomain(cls, domain, codomain):
if not isinstance(domain, LMSpace):
raise TypeError("domain is not a LMSpace.")
if not isinstance(codomain, HPSpace):
raise TypeError("codomain must be a HPSpace.")
nside = codomain.nside
lmax = domain.lmax
super(LMHPTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
def _transformation_of_slice(self, inp):
nside = self.codomain.nside
lmax = self.domain.lmax
mmax = lmax
......
......@@ -37,7 +37,7 @@ class Transform(object):
self.domain = domain
self.codomain = codomain
def transform(self, val, axes, **kwargs):
def transform(self, val, axes):
"""
A generic ff-transform function.
......@@ -66,25 +66,22 @@ class SerialFFT(Transform):
pyfftw.interfaces.cache.enable()
def transform(self, val, axes, **kwargs):
def transform(self, val, axes):
"""
The scalar FFT transform function.
Parameters
----------
val : distributed_data_object or numpy.ndarray
val : or numpy.ndarray
The value-array of the field which is supposed to
be transformed.
axes: tuple, None
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are passed to the create_mpi_plan routine.
Returns
-------
result : np.ndarray or distributed_data_object
result : numpy.ndarray
Fourier-transformed pendant of the input field.
"""
......@@ -93,24 +90,13 @@ class SerialFFT(Transform):
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
return_val = np.empty(val.shape, dtype=np.complex)
local_val = val
result_data = self._atomic_transform(local_val=local_val,
axes=axes,
local_offset_Q=False)
return_val=result_data
return return_val
return self._atomic_transform(local_val=val,
axes=axes,
local_offset_Q=False)
def _atomic_transform(self, local_val, axes, local_offset_Q):
# perform the transformation
if self.codomain.harmonic:
result_val = pyfftw.interfaces.numpy_fft.fftn(
local_val, axes=axes)
return pyfftw.interfaces.numpy_fft.fftn(local_val, axes=axes)
else:
result_val = pyfftw.interfaces.numpy_fft.ifftn(
local_val, axes=axes)
return result_val
return pyfftw.interfaces.numpy_fft.ifftn(local_val, axes=axes)
......@@ -96,7 +96,7 @@ class RGRGTransformation(Transformation):
super(RGRGTransformation, cls).check_codomain(domain, codomain)
def transform(self, val, axes=None, **kwargs):
def transform(self, val, axes=None):
"""
RG -> RG transform method.
......@@ -119,22 +119,18 @@ class RGRGTransformation(Transformation):
# Perform the transformation
if issubclass(val.dtype.type, np.complexfloating):
Tval_real = self._transform.transform(val.real, axes,
**kwargs)
Tval_imag = self._transform.transform(val.imag, axes,
**kwargs)
Tval_real = self._transform.transform(val.real, axes)
Tval_imag = self._transform.transform(val.imag, axes)
if self.codomain.harmonic:
Tval_real.real += Tval_real.imag
Tval_real.imag = \
Tval_imag.real + Tval_imag.imag
Tval_real.imag = Tval_imag.real + Tval_imag.imag
else:
Tval_real.real -= Tval_real.imag
Tval_real.imag = \
Tval_imag.real - Tval_imag.imag
Tval_real.imag = Tval_imag.real - Tval_imag.imag
Tval = Tval_real
else:
Tval = self._transform.transform(val, axes, **kwargs)
Tval = self._transform.transform(val, axes)
if self.codomain.harmonic:
Tval.real += Tval.imag
else:
......
......@@ -25,35 +25,22 @@ from .transformation import Transformation
class SlicingTransformation(Transformation):
def transform(self, val, axes=None, **kwargs):
def transform(self, val, axes=None):
return_shape = np.array(val.shape)
return_shape[list(axes)] = self.codomain.shape
return_shape = tuple(return_shape)
return_val = None
for slice_list in utilities.get_slice_list(val.shape, axes):
if return_val is None:
return_val = np.empty(return_shape,dtype=val.dtype)
data = val[slice_list]
data = self._transformation_of_slice(data, **kwargs)
return_val[slice_list] = data
return_val[slice_list] = self._transformation_of_slice(
val[slice_list])
return return_val
def _combine_complex_result(self, resultReal, resultImag):
# construct correct complex dtype
one = resultReal.dtype.type(1)
result_dtype = np.dtype(type(one + 1j))
result = np.empty_like(resultReal, dtype=result_dtype)
result.real = resultReal
result.imag = resultImag
return result
return resultReal + 1j*resultImag
@abc.abstractmethod
def _transformation_of_slice(self, inp):
......
......@@ -48,5 +48,5 @@ class Transformation(with_metaclass(abc.ABCMeta, type('NewBase', (object,), {}))
def check_codomain(cls, domain, codomain):
pass
def transform(self, val, axes=None, **kwargs):
def transform(self, val, axes=None):
raise NotImplementedError
......@@ -113,8 +113,8 @@ class LinearOperator(with_metaclass(
def default_spaces(self):
return self._default_spaces
def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs)
def __call__(self, x, spaces=None):
return self.times(x, spaces)
def times(self, x, spaces=None):
""" Applies the Operator to a given Field.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment