Commit be3d197c authored by theos's avatar theos
Browse files

Merged fft_operator/transformation_factory into FFTOperator.

Started to consolidate the code base of LM <-> GL/HP transformations.
parent 007da99a
......@@ -5,7 +5,7 @@ from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\
nifty_configuration as gc,\
nifty_configuration as gc
from nifty.field_types import FieldType
......
from nifty.config import about
import nifty.nifty_utilities as utilities
from nifty.spaces import RGSpace,\
GLSpace,\
HPSpace,\
LMSpace
from nifty.operators.linear_operator import LinearOperator
from transformations import TransformationFactory
from transformations import RGRGTransformation,\
LMGLTransformation,\
LMHPTransformation,\
GLLMTransformation,\
HPLMTransformation,\
TransformationCache
class FFTOperator(LinearOperator):
# ---Class attributes---
default_codomain_dictionary = {RGSpace: RGSpace,
HPSpace: LMSpace,
GLSpace: LMSpace,
LMSpace: HPSpace,
}
transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
(HPSpace, LMSpace): HPLMTransformation,
(GLSpace, LMSpace): GLLMTransformation,
(LMSpace, HPSpace): LMHPTransformation,
(LMSpace, GLSpace): LMGLTransformation
}
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=None):
def __init__(self, domain=(), field_type=(), target=None, module=None):
super(FFTOperator, self).__init__(domain=domain,
field_type=field_type)
# Initialize domain and target
if len(self.domain) != 1:
raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator accepts only exactly one '
......@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator):
))
if target is None:
target = utilities.get_default_codomain(self.domain[0])
target = (self.get_default_codomain(self.domain[0]), )
self._target = self._parse_domain(target)
self._forward_transformation = TransformationFactory.create(
self.domain[0], self.target[0]
)
self._inverse_transformation = TransformationFactory.create(
self.target[0], self.domain[0]
)
# Create transformation instances
try:
forward_class = self.transformation_dictionary[
(self.domain[0].__class__, self.target[0].__class__)]
except KeyError:
raise TypeError(about._errors.cstring(
"ERROR: No forward transformation for domain-target pair "
"found."))
try:
backward_class = self.transformation_dictionary[
(self.target[0].__class__, self.domain[0].__class__)]
except KeyError:
raise TypeError(about._errors.cstring(
"ERROR: No backward transformation for domain-target pair "
"found."))
self._forward_transformation = TransformationCache.create(
forward_class, self.domain[0], self.target[0], module=module)
self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module)
def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
......@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator):
else:
axes = x.domain_axes[spaces[0]]
new_val = self._inverse_transformation.transform(x.val, axes=axes)
new_val = self._backward_transformation.transform(x.val, axes=axes)
if spaces is None:
result_domain = self.domain
......@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator):
@property
def unitary(self):
return True
# ---Added properties and methods---
@classmethod
def get_default_codomain(cls, domain):
domain_class = domain.__class__
try:
codomain_class = cls.default_codomain_dictionary[domain_class]
except KeyError:
raise TypeError(about._errors.cstring("ERROR: unknown domain"))
try:
transform_class = cls.transformation_dictionary[(domain_class,
codomain_class)]
except KeyError:
raise TypeError(about._errors.cstring(
"ERROR: No transformation for domain-codomain pair found."))
return transform_class.get_codomain(domain)
......@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation
from transformation_factory import TransformationFactory
\ No newline at end of file
from transformation_cache import TransformationCache
\ No newline at end of file
import numpy as np
from transformation import Transformation
from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty.config import dependency_injector as gdi,\
about
from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory as ltf
gl = gdi.get('libsharp_wrapper_gl')
libsharp = gdi.get('libsharp_wrapper_gl')
class GLLMTransformation(Transformation):
class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if 'libsharp_wrapper_gl' not in gdi:
raise ImportError("The module libsharp is needed but not available")
raise ImportError(about._errors.cstring(
"The module libsharp is needed but not available."))
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
super(GLLMTransformation, self).__init__(domain, codomain, module)
# ---Mandatory properties and methods---
@staticmethod
def get_codomain(domain):
......@@ -40,10 +38,12 @@ class GLLMTransformation(Transformation):
A compatible codomain.
"""
if domain is None:
raise ValueError('ERROR: cannot generate codomain for None')
raise ValueError(about._errors.cstring(
"ERROR: cannot generate codomain for None"))
if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain needs to be a GLSpace')
raise TypeError(about._errors.cstring(
"ERROR: domain needs to be a GLSpace"))
nlat = domain.nlat
lmax = nlat - 1
......@@ -53,16 +53,18 @@ class GLLMTransformation(Transformation):
else:
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.complex128)
@staticmethod
def check_codomain(domain, codomain):
@classmethod
def check_codomain(cls, domain, codomain):
if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain is not a GLSpace')
raise TypeError(about._errors.cstring(
"ERROR: domain is not a GLSpace"))
if codomain is None:
return False
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
raise TypeError(about._errors.cstring(
"ERROR: codomain must be a LMSpace."))
nlat = domain.nlat
nlon = domain.nlon
......@@ -74,74 +76,45 @@ class GLLMTransformation(Transformation):
return True
def transform(self, val, axes=None, **kwargs):
"""
GL -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if self.domain.discrete:
val = self.domain.weight(val, power=-0.5, axes=axes)
# ---Added properties and methods---
def _transformation_of_slice(self, inp):
# shorthands for transform parameters
nlat = self.domain.nlat
nlon = self.domain.nlon
lmax = self.codomain.lmax
mmax = self.codomain.mmax
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
else:
temp_val = val
if issubclass(inp.dtype.type, np.complexfloating):
return_val = None
[resultReal, resultImag] = [self.libsharpMap2Alm(x,
nlat=nlat,
nlon=nlon,
lmax=lmax,
mmax=mmax)
for x in (inp.real, inp.imag)]
for slice_list in utilities.get_slice_list(temp_val.shape, axes):
if slice_list == [slice(None, None)]:
inp = temp_val
else:
if return_val is None:
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
if inp.dtype >= np.dtype('complex64'):
inpReal = self.GlMap2Alm(
np.real(inp).astype(np.float64, copy=False), nlat=nlat,
nlon=nlon, lmax=lmax, mmax=mmax)
inpImg = self.GlMap2Alm(
np.imag(inp).astype(np.float64, copy=False), nlat=nlat,
nlon=nlon, lmax=lmax, mmax=mmax)
inpReal = ltf.buildIdx(inpReal, lmax=lmax)
inpImg = ltf.buildIdx(inpImg, lmax=lmax)
inp = inpReal + inpImg * 1j
else:
inp = self.GlMap2Alm(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
inp = ltf.buildIdx(inp, lmax=lmax)
if slice_list == [slice(None, None)]:
return_val = inp
else:
return_val[slice_list] = inp
resultReal = ltf.buildIdx(resultReal, lmax=lmax)
resultImag = ltf.buildIdx(resultImag, lmax=lmax)
# construct correct complex dtype
one = resultReal.dtype.type(1)
result_dtype = np.dtype(type(one + 1j))
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
result = np.empty_like(resultReal, dtype=result_dtype)
result.real = resultReal
result.imag = resultImag
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
result = self.libsharpMap2Alm(inp, nlat=nlat, nlon=nlon, lmax=lmax,
mmax=mmax)
result = ltf.buildIdx(result, lmax=lmax)
return return_val
return result
def GlMap2Alm(self, inp, **kwargs):
def libsharpMap2Alm(self, inp, **kwargs):
if inp.dtype == np.dtype('float32'):
return gl.map2alm_f(inp, kwargs)
return libsharp.map2alm_f(inp, **kwargs)
elif inp.dtype == np.dtype('float64'):
return libsharp.map2alm(inp, **kwargs)
else:
return gl.map.alm(inp, kwargs)
about.warnings.cprint("WARNING: performing dtype conversion for "
"libsharp compatibility.")
......@@ -7,17 +7,13 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None):
if codomain is None:
codomain = self.get_codomain(domain)
else:
if not self.check_codomain(domain, codomain):
raise ValueError("ERROR: incompatible codomain!")
super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None:
if nifty_configuration['fft_module'] == 'pyfftw':
self._transform = FFTW(domain, codomain)
elif nifty_configuration['fft_module'] == 'gfft' or \
nifty_configuration['fft_module'] == 'gfft_dummy':
elif (nifty_configuration['fft_module'] == 'gfft' or
nifty_configuration['fft_module'] == 'gfft_dummy'):
self._transform = \
GFFT(domain,
codomain,
......@@ -73,7 +69,9 @@ class RGRGTransformation(Transformation):
distances = 1 / (np.array(domain.shape) *
np.array(domain.distances))
if dtype is None:
dtype = np.complex
# create a definitely complex dtype from the dtype of domain
one = domain.dtype.type(1)
dtype = np.dtype(type(one + 1j))
new_space = RGSpace(domain.shape,
zerocenter=zerocenter,
......@@ -86,7 +84,7 @@ class RGRGTransformation(Transformation):
@staticmethod
def check_codomain(domain, codomain):
if not isinstance(domain, RGSpace):
raise TypeError('ERROR: domain must be a RGSpace')
raise TypeError('ERROR: domain is not a RGSpace')
if codomain is None:
return False
......@@ -101,6 +99,11 @@ class RGRGTransformation(Transformation):
if domain.harmonic == codomain.harmonic:
return False
if codomain.harmonic and not issubclass(codomain.dtype.type,
np.complexfloating):
about.warnings.cprint(
"WARNING: codomain is harmonic but dtype is real.")
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(
np.absolute(np.array(domain.shape) *
......
# -*- coding: utf-8 -*-
import abc
import numpy as np
import nifty.nifty_utilities as utilities
from transformation import Transformation
class SlicingTransformation(Transformation):
def transform(self, val, axes=None, **kwargs):
return_shape = np.array(val.shape)
return_shape[list(axes)] = self.codomain.shape
return_shape = tuple(return_shape)
return_val = None
for slice_list in utilities.get_slice_list(val.shape, axes):
if return_val is None:
return_val = val.copy_empty(dtype=self.codomain.dtype,
global_shape=return_shape)
data = val[slice_list]
data = data.get_full_data()
data = self._transformation_of_slice(data)
return_val[slice_list] = data
return return_val
@abc.abstractmethod
def _transformation_of_slice(self, inp):
raise NotImplementedError
import abc
class Transformation(object):
"""
A generic transformation which defines a static check_codomain
method for all transforms.
"""
__metaclass__ = abc.ABCMeta
def __init__(self, domain, codomain=None, module=None):
pass
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
elif self.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Incompatible codomain!")
@staticmethod
def get_codomain(domain, dtype=None, zerocenter=None, **kwargs):
raise NotImplementedError
@staticmethod
def check_codomain(domain, codomain):
raise NotImplementedError
def transform(self, val, axes=None, **kwargs):
raise NotImplementedError
class _TransformationCache(object):
def __init__(self):
self.cache = {}
def create(self, transformation_class, domain, codomain, module):
key = domain.__hash__() ^ ((codomain.__hash__()/111) ^
(module.__hash__())/179)
if key not in self.cache:
self.cache[key] = transformation_class(domain, codomain, module)
return self.cache[key]
TransformationCache = _TransformationCache()
from nifty.spaces import RGSpace, GLSpace, HPSpace, LMSpace
from rgrgtransformation import RGRGTransformation
from gllmtransformation import GLLMTransformation
from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation
class _TransformationFactory(object):
"""
Transform factory which generates transform objects
"""
def __init__(self):
# cache for storing the transform objects
self.cache = {}
def _get_transform(self, domain, codomain, module):
if isinstance(domain, RGSpace):
if isinstance(codomain, RGSpace):
return RGRGTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
elif isinstance(domain, GLSpace):
if isinstance(codomain, LMSpace):
return GLLMTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
elif isinstance(domain, HPSpace):
if isinstance(codomain, LMSpace):
return HPLMTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
elif isinstance(domain, LMSpace):
if isinstance(codomain, GLSpace):
return LMGLTransformation(domain, codomain, module)
elif isinstance(codomain, HPSpace):
return LMHPTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
else:
raise ValueError('ERROR: unknown domain')
def create(self, domain, codomain, module=None):
key = domain.__hash__() ^ ((codomain.__hash__()/111) ^
(module.__hash__())/179)
if key not in self.cache:
self.cache[key] = self._get_transform(domain, codomain, module)
return self.cache[key]
TransformationFactory = _TransformationFactory()
......@@ -70,7 +70,7 @@ class GLSpace(Space):
# ---Overwritten properties and methods---
def __init__(self, nlat=2, nlon=3, dtype=np.dtype('float')):
def __init__(self, nlat=2, nlon=None, dtype=np.dtype('float')):
"""
Sets the attributes for a gl_space class instance.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment