diff --git a/nifty/__init__.py b/nifty/__init__.py index fbfb90ca74e31b79ec163017d448e633dbb34c94..06e1af705d7ad7104aa2caf5f12abc59652aeb40 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -54,8 +54,6 @@ from field_types import FieldType,\ from spaces import * -from transformations import * - from operators import * from demos import get_demo_dir diff --git a/nifty/nifty_utilities.py b/nifty/nifty_utilities.py index 2bafae96500ea71a0e9183976ebd682088e80784..6da736a2a08c7de7499498ca3ccfffe61d3ce7d5 100644 --- a/nifty/nifty_utilities.py +++ b/nifty/nifty_utilities.py @@ -41,14 +41,14 @@ def get_slice_list(shape, axes): axes(axis) does not match shape.") ) axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)] - axes_iterables =\ + axes_iterables = \ [range(y) for x, y in enumerate(shape) if x not in axes] for index in product(*axes_iterables): it_iter = iter(index) slice_list = [ next(it_iter) if axis else slice(None, None) for axis in axes_select - ] + ] yield slice_list else: yield [slice(None, None)] @@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None): # The fixed points of the point inversion must not be avaraged. # Hence one must multiply them again with sqrt(0.5) # -> Get the middle index of the array - mid_index = np.array(x.shape, dtype=np.int)//2 + mid_index = np.array(x.shape, dtype=np.int) // 2 dimensions = mid_index.size # Use ndindex to iterate over all combinations of zeros and the # mid_index in order to correct all fixed points. @@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None): ndlist = [2 if i in axes else 1 for i in xrange(dimensions)] ndlist = tuple(ndlist) for i in np.ndindex(ndlist): - temp_index = tuple(i*mid_index) + temp_index = tuple(i * mid_index) x[temp_index] *= np.sqrt(0.5) try: x.hermitian = True @@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes): # calculate the number of dimensions the input array has dimensions = len(x.shape) # prepare the slicing object which will be used for mirroring - slice_primitive = [slice(None), ]*dimensions + slice_primitive = [slice(None), ] * dimensions # copy the input data y = x.copy() @@ -208,8 +208,9 @@ def field_map(ishape, function, *args): # with ishape (3,4,3) and (3,4,1) def get_clipped(w, ind): w_shape = np.array(np.shape(w)) - get_tuple = tuple(np.clip(ind, 0, w_shape-1)) + get_tuple = tuple(np.clip(ind, 0, w_shape - 1)) return w[get_tuple] + result = np.empty_like(args[0]) for i in xrange(reduce(lambda x, y: x * y, result.shape)): ii = np.unravel_index(i, result.shape) @@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length): axis = tuple(int(item) for item in axis) except(TypeError): if np.isscalar(axis): - axis = (int(axis), ) + axis = (int(axis),) else: raise TypeError( "ERROR: Could not convert axis-input to tuple of ints") @@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length): # assert that all entries are elements in [0, length] for elem in axis: - assert(0 <= elem < length) + assert (0 <= elem < length) return axis @@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None): return real_bincount + imag_bincount else: return x.bincount(weights=weights, minlength=minlength) + + +def get_default_codomain(domain): + from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace + from nifty.operators.fft_operator.transformations import RGRGTransformation, \ + HPLMTransformation, GLLMTransformation, LMGLTransformation + + if isinstance(domain, RGSpace): + return RGRGTransformation.get_codomain(domain) + elif isinstance(domain, HPSpace): + return HPLMTransformation.get_codomain(domain) + elif isinstance(domain, GLSpace): + return GLLMTransformation.get_codomain(domain) + elif isinstance(domain, LMSpace): + # TODO: get the preferred transformation path from config + return LMGLTransformation.get_codomain(domain) + else: + raise TypeError(about._errors.cstring('ERROR: unknown domain')) diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index dfd5359cd43d5f4e77ebd8181084c5866b799037..8e947122b701defab8b4cbcda8c63ad8a8c0d0ad 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -25,7 +25,7 @@ from linear_operator import LinearOperator from endomorphic_operator import EndomorphicOperator -from transformation_operator import TransformationOperator +from fft_operator import * from nifty_operators import operator,\ diagonal_operator,\ diff --git a/nifty/operators/fft_operator/__init__.py b/nifty/operators/fft_operator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14eb7b8f94bc993ece0d0cf3cd2accdd3a6c32bd --- /dev/null +++ b/nifty/operators/fft_operator/__init__.py @@ -0,0 +1,2 @@ +from transformations import * +from fft_operator import FFTOperator \ No newline at end of file diff --git a/nifty/operators/transformation_operator/transformation_operator.py b/nifty/operators/fft_operator/fft_operator.py similarity index 78% rename from nifty/operators/transformation_operator/transformation_operator.py rename to nifty/operators/fft_operator/fft_operator.py index 168d36ff28ec7043c8c758bb9e4770bf0d694ac3..01278de677fc0759a3374eb2fdb3a3925f0e80a5 100644 --- a/nifty/operators/transformation_operator/transformation_operator.py +++ b/nifty/operators/fft_operator/fft_operator.py @@ -1,15 +1,18 @@ from nifty.config import about import nifty.nifty_utilities as utilities from nifty.operators.linear_operator import LinearOperator -from nifty.transformations import TransformationFactory +from transformations import TransformationFactory -class TransformationOperator(LinearOperator): +class FFTOperator(LinearOperator): + + # ---Overwritten properties and methods--- + def __init__(self, domain=(), field_type=(), target=(), field_type_target=(), implemented=True): - super(TransformationOperator, self).__init__(domain=domain, - field_type=field_type, - implemented=implemented) + super(FFTOperator, self).__init__(domain=domain, + field_type=field_type, + implemented=implemented) if self.domain == (): raise TypeError(about._errors.cstring( @@ -19,8 +22,8 @@ class TransformationOperator(LinearOperator): else: if len(self.domain) > 1: raise TypeError(about._errors.cstring( - 'ERROR: TransformationOperator accepts only a single' - 'space as input' + 'ERROR: TransformationOperator accepts only a single ' + 'space as input domain.' )) if self.field_type != (): @@ -30,7 +33,9 @@ class TransformationOperator(LinearOperator): )) # currently not sanitizing the target - self._target = self._parse_domain(target) + self._target = self._parse_domain( + utilities.get_default_codomain(self.domain[0]) + ) self._field_type_target = self._parse_field_type(field_type_target) if self.field_type_target != (): @@ -47,14 +52,6 @@ class TransformationOperator(LinearOperator): self.target[0], self.domain[0] ) - @property - def target(self): - return self._target - - @property - def field_type_target(self): - return self._field_type_target - def _times(self, x, spaces, types): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) @@ -64,3 +61,14 @@ class TransformationOperator(LinearOperator): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) return self._inverse_transformation.transform(x.val, axes=spaces) + + # ---Mandatory properties and methods--- + + @property + def target(self): + return self._target + + @property + def field_type_target(self): + return self._field_type_target + diff --git a/nifty/transformations/__init__.py b/nifty/operators/fft_operator/transformations/__init__.py similarity index 80% rename from nifty/transformations/__init__.py rename to nifty/operators/fft_operator/transformations/__init__.py index 35d89b2cbeffbea3f06136136954c95ba0b69e9e..42e6101b05ad827c789d25e3239b4897a50639f3 100644 --- a/nifty/transformations/__init__.py +++ b/nifty/operators/fft_operator/transformations/__init__.py @@ -1,9 +1,7 @@ - from rgrgtransformation import RGRGTransformation from gllmtransformation import GLLMTransformation from hplmtransformation import HPLMTransformation from lmgltransformation import LMGLTransformation from lmhptransformation import LMHPTransformation -from transformation_factory import TransformationFactory - +from transformation_factory import TransformationFactory \ No newline at end of file diff --git a/nifty/transformations/gllmtransformation.py b/nifty/operators/fft_operator/transformations/gllmtransformation.py similarity index 98% rename from nifty/transformations/gllmtransformation.py rename to nifty/operators/fft_operator/transformations/gllmtransformation.py index d8a646cd4508d8165657135cabb7e4469ec227a5..73abac3b1f898f0a4163408a3ce7945d2a666626 100644 --- a/nifty/transformations/gllmtransformation.py +++ b/nifty/operators/fft_operator/transformations/gllmtransformation.py @@ -87,7 +87,7 @@ class GLLMTransformation(Transformation): """ if self.domain.discrete: - val = self.domain.calc_weight(val, power=-0.5) + val = self.domain.weight(val, power=-0.5, axes=axes) # shorthands for transform parameters nlat = self.domain.paradict['nlat'] diff --git a/nifty/transformations/hplmtransformation.py b/nifty/operators/fft_operator/transformations/hplmtransformation.py similarity index 98% rename from nifty/transformations/hplmtransformation.py rename to nifty/operators/fft_operator/transformations/hplmtransformation.py index 1de37b676967d095354ec3bc81445a1e1b1ab9e9..d1da272efb34cf0ed7d56e92baef25c87f26f4e8 100644 --- a/nifty/transformations/hplmtransformation.py +++ b/nifty/operators/fft_operator/transformations/hplmtransformation.py @@ -85,7 +85,7 @@ class HPLMTransformation(Transformation): niter = kwargs['niter'] if 'niter' in kwargs else 0 if self.domain.discrete: - val = self.domain.calc_weight(val, power=-0.5) + val = self.domain.weight(val, power=-0.5, axes=axes) # shorthands for transform parameters lmax = self.codomain.paradict['lmax'] diff --git a/nifty/transformations/lmgltransformation.py b/nifty/operators/fft_operator/transformations/lmgltransformation.py similarity index 98% rename from nifty/transformations/lmgltransformation.py rename to nifty/operators/fft_operator/transformations/lmgltransformation.py index c188fbe023468b0bad0fdc5891dc2ceecb7ec80c..5d54c9bf43e56b7867e29f7f29004133aebe3d69 100644 --- a/nifty/transformations/lmgltransformation.py +++ b/nifty/operators/fft_operator/transformations/lmgltransformation.py @@ -131,7 +131,7 @@ class LMGLTransformation(Transformation): # re-weight if discrete if self.codomain.discrete: - val = self.codomain.calc_weight(val, power=0.5) + val = self.codomain.weight(val, power=0.5, axes=axes) if isinstance(val, distributed_data_object): new_val = val.copy_empty(dtype=self.codomain.dtype) diff --git a/nifty/transformations/lmhptransformation.py b/nifty/operators/fft_operator/transformations/lmhptransformation.py similarity index 98% rename from nifty/transformations/lmhptransformation.py rename to nifty/operators/fft_operator/transformations/lmhptransformation.py index c9f00bab5b874ce18fdc7be356dc3ccc1f506607..f2692ba7b59a2adab294a62af57c89113d8b2b29 100644 --- a/nifty/transformations/lmhptransformation.py +++ b/nifty/operators/fft_operator/transformations/lmhptransformation.py @@ -114,7 +114,7 @@ class LMHPTransformation(Transformation): # re-weight if discrete if self.codomain.discrete: - val = self.codomain.calc_weight(val, power=0.5) + val = self.codomain.weight(val, power=0.5, axes=axes) if isinstance(val, distributed_data_object): new_val = val.copy_empty(dtype=self.codomain.dtype) diff --git a/nifty/transformations/rg_transforms.py b/nifty/operators/fft_operator/transformations/rg_transforms.py similarity index 100% rename from nifty/transformations/rg_transforms.py rename to nifty/operators/fft_operator/transformations/rg_transforms.py diff --git a/nifty/transformations/rgrgtransformation.py b/nifty/operators/fft_operator/transformations/rgrgtransformation.py similarity index 98% rename from nifty/transformations/rgrgtransformation.py rename to nifty/operators/fft_operator/transformations/rgrgtransformation.py index b643eda0fb9723a04dda7dd94fde3f9cf91294e1..758d08d150c8e41075e5dd9edb2f81b4945e781d 100644 --- a/nifty/transformations/rgrgtransformation.py +++ b/nifty/operators/fft_operator/transformations/rgrgtransformation.py @@ -126,13 +126,13 @@ class RGRGTransformation(Transformation): """ if self._transform.codomain.harmonic: # correct for forward fft - val = self._transform.domain.weight(val, power=1) + val = self._transform.domain.weight(val, power=1, axes=axes) # Perform the transformation Tval = self._transform.transform(val, axes, **kwargs) if not self._transform.codomain.harmonic: # correct for inverse fft - Tval = self._transform.codomain.weight(Tval, power=-1) + Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes) return Tval diff --git a/nifty/transformations/transformation.py b/nifty/operators/fft_operator/transformations/transformation.py similarity index 100% rename from nifty/transformations/transformation.py rename to nifty/operators/fft_operator/transformations/transformation.py diff --git a/nifty/transformations/transformation_factory.py b/nifty/operators/fft_operator/transformations/transformation_factory.py similarity index 100% rename from nifty/transformations/transformation_factory.py rename to nifty/operators/fft_operator/transformations/transformation_factory.py diff --git a/nifty/operators/transformation_operator/__init__.py b/nifty/operators/transformation_operator/__init__.py deleted file mode 100644 index 474fd83d00b6100b492e43f1f29e46d511000480..0000000000000000000000000000000000000000 --- a/nifty/operators/transformation_operator/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from transformation_operator import TransformationOperator \ No newline at end of file diff --git a/test/test_nifty_transforms.py b/test/test_nifty_transforms.py index aa358298396c5b0598cf416c12edc42d5708a092..5579442e2657ca1bfa5df8f29d1fbe91823b540c 100644 --- a/test/test_nifty_transforms.py +++ b/test/test_nifty_transforms.py @@ -1,15 +1,15 @@ -import numpy as np -from numpy.testing import assert_equal, assert_almost_equal, assert_raises +import itertools +import unittest +import d2o +import numpy as np +from nifty.rg.rg_space import gc as RG_GC from nose_parameterized import parameterized -import unittest -import itertools +from numpy.testing import assert_raises from nifty import RGSpace, LMSpace, HPSpace, GLSpace from nifty import transformator -from nifty.transformations.rgrgtransformation import RGRGTransformation -from nifty.rg.rg_space import gc as RG_GC -import d2o +from nifty.operators.fft_operator.transformations.rgrgtransformation import RGRGTransformation ###############################################################################