Commit be3d197c authored by theos's avatar theos
Browse files

Merged fft_operator/transformation_factory into FFTOperator.

Started to consolidate the code base of LM <-> GL/HP transformations.
parent 007da99a
...@@ -5,7 +5,7 @@ from d2o import distributed_data_object,\ ...@@ -5,7 +5,7 @@ from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\ from nifty.config import about,\
nifty_configuration as gc,\ nifty_configuration as gc
from nifty.field_types import FieldType from nifty.field_types import FieldType
......
from nifty.config import about from nifty.config import about
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
from nifty.spaces import RGSpace,\
GLSpace,\
HPSpace,\
LMSpace
from nifty.operators.linear_operator import LinearOperator from nifty.operators.linear_operator import LinearOperator
from transformations import TransformationFactory from transformations import RGRGTransformation,\
LMGLTransformation,\
LMHPTransformation,\
GLLMTransformation,\
HPLMTransformation,\
TransformationCache
class FFTOperator(LinearOperator): class FFTOperator(LinearOperator):
# ---Class attributes---
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
GLSpace: LMSpace,
LMSpace: HPSpace,
}
transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
(HPSpace, LMSpace): HPLMTransformation,
(GLSpace, LMSpace): GLLMTransformation,
(LMSpace, HPSpace): LMHPTransformation,
(LMSpace, GLSpace): LMGLTransformation
}
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=None): def __init__(self, domain=(), field_type=(), target=None, module=None):
super(FFTOperator, self).__init__(domain=domain, super(FFTOperator, self).__init__(domain=domain,
field_type=field_type) field_type=field_type)
# Initialize domain and target
if len(self.domain) != 1: if len(self.domain) != 1:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator accepts only exactly one ' 'ERROR: TransformationOperator accepts only exactly one '
...@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator): ...@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator):
)) ))
if target is None: if target is None:
target = utilities.get_default_codomain(self.domain[0]) target = (self.get_default_codomain(self.domain[0]), )
self._target = self._parse_domain(target) self._target = self._parse_domain(target)
self._forward_transformation = TransformationFactory.create( # Create transformation instances
self.domain[0], self.target[0] try:
) forward_class = self.transformation_dictionary[
(self.domain[0].__class__, self.target[0].__class__)]
self._inverse_transformation = TransformationFactory.create( except KeyError:
self.target[0], self.domain[0] raise TypeError(about._errors.cstring(
) "ERROR: No forward transformation for domain-target pair "
"found."))
try:
backward_class = self.transformation_dictionary[
(self.target[0].__class__, self.domain[0].__class__)]
except KeyError:
raise TypeError(about._errors.cstring(
"ERROR: No backward transformation for domain-target pair "
"found."))
self._forward_transformation = TransformationCache.create(
forward_class, self.domain[0], self.target[0], module=module)
self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module)
def _times(self, x, spaces, types): def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
...@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator): ...@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator):
else: else:
axes = x.domain_axes[spaces[0]] axes = x.domain_axes[spaces[0]]
new_val = self._inverse_transformation.transform(x.val, axes=axes) new_val = self._backward_transformation.transform(x.val, axes=axes)
if spaces is None: if spaces is None:
result_domain = self.domain result_domain = self.domain
...@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator): ...@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator):
@property @property
def unitary(self): def unitary(self):
return True return True
# ---Added properties and methods---
@classmethod
def get_default_codomain(cls, domain):
domain_class = domain.__class__
try:
codomain_class = cls.default_codomain_dictionary[domain_class]
except KeyError:
raise TypeError(about._errors.cstring("ERROR: unknown domain"))
try:
transform_class = cls.transformation_dictionary[(domain_class,
codomain_class)]
except KeyError:
raise TypeError(about._errors.cstring(
"ERROR: No transformation for domain-codomain pair found."))
return transform_class.get_codomain(domain)
...@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation ...@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation from lmhptransformation import LMHPTransformation
from transformation_factory import TransformationFactory from transformation_cache import TransformationCache
\ No newline at end of file \ No newline at end of file
import numpy as np import numpy as np
from transformation import Transformation from nifty.config import dependency_injector as gdi,\
from d2o import distributed_data_object about
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf import lm_transformation_factory as ltf
gl = gdi.get('libsharp_wrapper_gl') libsharp = gdi.get('libsharp_wrapper_gl')
class GLLMTransformation(Transformation): class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if 'libsharp_wrapper_gl' not in gdi: if 'libsharp_wrapper_gl' not in gdi:
raise ImportError("The module libsharp is needed but not available") raise ImportError(about._errors.cstring(
"The module libsharp is needed but not available."))
if codomain is None: super(GLLMTransformation, self).__init__(domain, codomain, module)
self.domain = domain
self.codomain = self.get_codomain(domain) # ---Mandatory properties and methods---
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod @staticmethod
def get_codomain(domain): def get_codomain(domain):
...@@ -40,10 +38,12 @@ class GLLMTransformation(Transformation): ...@@ -40,10 +38,12 @@ class GLLMTransformation(Transformation):
A compatible codomain. A compatible codomain.
""" """
if domain is None: if domain is None:
raise ValueError('ERROR: cannot generate codomain for None') raise ValueError(about._errors.cstring(
"ERROR: cannot generate codomain for None"))
if not isinstance(domain, GLSpace): if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain needs to be a GLSpace') raise TypeError(about._errors.cstring(
"ERROR: domain needs to be a GLSpace"))
nlat = domain.nlat nlat = domain.nlat
lmax = nlat - 1 lmax = nlat - 1
...@@ -53,16 +53,18 @@ class GLLMTransformation(Transformation): ...@@ -53,16 +53,18 @@ class GLLMTransformation(Transformation):
else: else:
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex128) return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex128)
@staticmethod @classmethod
def check_codomain(domain, codomain): def check_codomain(cls, domain, codomain):
if not isinstance(domain, GLSpace): if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain is not a GLSpace') raise TypeError(about._errors.cstring(
"ERROR: domain is not a GLSpace"))
if codomain is None: if codomain is None:
return False return False
if not isinstance(codomain, LMSpace): if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.') raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."))
nlat = domain.nlat nlat = domain.nlat
nlon = domain.nlon nlon = domain.nlon
...@@ -74,74 +76,45 @@ class GLLMTransformation(Transformation): ...@@ -74,74 +76,45 @@ class GLLMTransformation(Transformation):
return True return True
def transform(self, val, axes=None, **kwargs): # ---Added properties and methods---
"""
GL -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if self.domain.discrete:
val = self.domain.weight(val, power=-0.5, axes=axes)
def _transformation_of_slice(self, inp):
# shorthands for transform parameters # shorthands for transform parameters
nlat = self.domain.nlat nlat = self.domain.nlat
nlon = self.domain.nlon nlon = self.domain.nlon
lmax = self.codomain.lmax lmax = self.codomain.lmax
mmax = self.codomain.mmax mmax = self.codomain.mmax
if isinstance(val, distributed_data_object): if issubclass(inp.dtype.type, np.complexfloating):
temp_val = val.get_full_data()
else: [resultReal, resultImag] = [self.libsharpMap2Alm(x,
temp_val = val nlat=nlat,
nlon=nlon,
return_val = None lmax=lmax,
mmax=mmax)
for slice_list in utilities.get_slice_list(temp_val.shape, axes): for x in (inp.real, inp.imag)]
if slice_list == [slice(None, None)]:
inp = temp_val resultReal = ltf.buildIdx(resultReal, lmax=lmax)
else: resultImag = ltf.buildIdx(resultImag, lmax=lmax)
if return_val is None: # construct correct complex dtype
return_val = np.empty_like(temp_val) one = resultReal.dtype.type(1)
inp = temp_val[slice_list] result_dtype = np.dtype(type(one + 1j))
if inp.dtype >= np.dtype('complex64'): result = np.empty_like(resultReal, dtype=result_dtype)
inpReal = self.GlMap2Alm( result.real = resultReal
np.real(inp).astype(np.float64, copy=False), nlat=nlat, result.imag = resultImag
nlon=nlon, lmax=lmax, mmax=mmax)
inpImg = self.GlMap2Alm(
np.imag(inp).astype(np.float64, copy=False), nlat=nlat,
nlon=nlon, lmax=lmax, mmax=mmax)
inpReal = ltf.buildIdx(inpReal, lmax=lmax)
inpImg = ltf.buildIdx(inpImg, lmax=lmax)
inp = inpReal + inpImg * 1j
else:
inp = self.GlMap2Alm(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
inp = ltf.buildIdx(inp, lmax=lmax)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
else: else:
return_val = return_val.astype(self.codomain.dtype, copy=False) result = self.libsharpMap2Alm(inp, nlat=nlat, nlon=nlon, lmax=lmax,
mmax=mmax)
result = ltf.buildIdx(result, lmax=lmax)
return return_val return result
def GlMap2Alm(self, inp, **kwargs): def libsharpMap2Alm(self, inp, **kwargs):
if inp.dtype == np.dtype('float32'): if inp.dtype == np.dtype('float32'):
return gl.map2alm_f(inp, kwargs) return libsharp.map2alm_f(inp, **kwargs)
elif inp.dtype == np.dtype('float64'):
return libsharp.map2alm(inp, **kwargs)
else: else:
return gl.map.alm(inp, kwargs) about.warnings.cprint("WARNING: performing dtype conversion for "
"libsharp compatibility.")
...@@ -7,17 +7,13 @@ from nifty import RGSpace, nifty_configuration ...@@ -7,17 +7,13 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation): class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if codomain is None: super(RGRGTransformation, self).__init__(domain, codomain, module)
codomain = self.get_codomain(domain)
else:
if not self.check_codomain(domain, codomain):
raise ValueError("ERROR: incompatible codomain!")
if module is None: if module is None:
if nifty_configuration['fft_module'] == 'pyfftw': if nifty_configuration['fft_module'] == 'pyfftw':
self._transform = FFTW(domain, codomain) self._transform = FFTW(domain, codomain)
elif nifty_configuration['fft_module'] == 'gfft' or \ elif (nifty_configuration['fft_module'] == 'gfft' or
nifty_configuration['fft_module'] == 'gfft_dummy': nifty_configuration['fft_module'] == 'gfft_dummy'):
self._transform = \ self._transform = \
GFFT(domain, GFFT(domain,
codomain, codomain,
...@@ -73,7 +69,9 @@ class RGRGTransformation(Transformation): ...@@ -73,7 +69,9 @@ class RGRGTransformation(Transformation):
distances = 1 / (np.array(domain.shape) * distances = 1 / (np.array(domain.shape) *
np.array(domain.distances)) np.array(domain.distances))
if dtype is None: if dtype is None:
dtype = np.complex # create a definitely complex dtype from the dtype of domain
one = domain.dtype.type(1)
dtype = np.dtype(type(one + 1j))
new_space = RGSpace(domain.shape, new_space = RGSpace(domain.shape,
zerocenter=zerocenter, zerocenter=zerocenter,
...@@ -86,7 +84,7 @@ class RGRGTransformation(Transformation): ...@@ -86,7 +84,7 @@ class RGRGTransformation(Transformation):
@staticmethod @staticmethod
def check_codomain(domain, codomain): def check_codomain(domain, codomain):
if not isinstance(domain, RGSpace): if not isinstance(domain, RGSpace):
raise TypeError('ERROR: domain must be a RGSpace') raise TypeError('ERROR: domain is not a RGSpace')
if codomain is None: if codomain is None:
return False return False
...@@ -101,6 +99,11 @@ class RGRGTransformation(Transformation): ...@@ -101,6 +99,11 @@ class RGRGTransformation(Transformation):
if domain.harmonic == codomain.harmonic: if domain.harmonic == codomain.harmonic:
return False return False
if codomain.harmonic and not issubclass(codomain.dtype.type,
np.complexfloating):
about.warnings.cprint(
"WARNING: codomain is harmonic but dtype is real.")
# Check if the distances match, i.e. dist' = 1 / (num * dist) # Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all( if not np.all(
np.absolute(np.array(domain.shape) * np.absolute(np.array(domain.shape) *
......
# -*- coding: utf-8 -*-
import abc
import numpy as np
import nifty.nifty_utilities as utilities
from transformation import Transformation
class SlicingTransformation(Transformation):
def transform(self, val, axes=None, **kwargs):
return_shape = np.array(val.shape)
return_shape[list(axes)] = self.codomain.shape
return_shape = tuple(return_shape)
return_val = None
for slice_list in utilities.get_slice_list(val.shape, axes):
if return_val is None:
return_val = val.copy_empty(dtype=self.codomain.dtype,
global_shape=return_shape)
data = val[slice_list]
data = data.get_full_data()
data = self._transformation_of_slice(data)
return_val[slice_list] = data
return return_val
@abc.abstractmethod
def _transformation_of_slice(self, inp):
raise NotImplementedError
import abc
class Transformation(object): class Transformation(object):
""" """
A generic transformation which defines a static check_codomain A generic transformation which defines a static check_codomain
method for all transforms. method for all transforms.
""" """
__metaclass__ = abc.ABCMeta
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
pass if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def get_codomain(domain, dtype=None, zerocenter=None, **kwargs):
raise NotImplementedError
@staticmethod
def check_codomain(domain, codomain):
raise NotImplementedError
def transform(self, val, axes=None, **kwargs): def transform(self, val, axes=None, **kwargs):
raise NotImplementedError raise NotImplementedError
class _TransformationCache(object):
def __init__(self):
self.cache = {}
def create(self, transformation_class, domain, codomain, module):
key = domain.__hash__() ^ ((codomain.__hash__()/111) ^
(module.__hash__())/179)
if key not in self.cache:
self.cache[key] = transformation_class(domain, codomain, module)
return self.cache[key]
TransformationCache = _TransformationCache()
from nifty.spaces import RGSpace, GLSpace, HPSpace, LMSpace
from rgrgtransformation import RGRGTransformation
from gllmtransformation import GLLMTransformation
from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation
class _TransformationFactory(object):
"""
Transform factory which generates transform objects
"""
def __init__(self):
# cache for storing the transform objects
self.cache = {}
def _get_transform(self, domain, codomain, module):
if isinstance(domain, RGSpace):
if isinstance(codomain, RGSpace):
return RGRGTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
elif isinstance(domain, GLSpace):
if isinstance(codomain, LMSpace):
return GLLMTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
elif isinstance(domain, HPSpace):