From 39530b9df29af6617bc6acaff099cf05b0f8ccd3 Mon Sep 17 00:00:00 2001 From: Jait Dixit <jait.dixit@tum.de> Date: Tue, 23 Aug 2016 14:46:53 +0200 Subject: [PATCH] WIP: FFTOperator - Add get_default_codomain function in utilities - Move all transforms to a subfolder inside the fft_operator - Correct weight method calls in transform method of each transformation --- nifty/__init__.py | 2 - nifty/nifty_utilities.py | 35 ++++++++++++---- nifty/operators/__init__.py | 2 +- nifty/operators/fft_operator/__init__.py | 2 + .../fft_operator.py} | 40 +++++++++++-------- .../fft_operator}/transformations/__init__.py | 4 +- .../transformations/gllmtransformation.py | 2 +- .../transformations/hplmtransformation.py | 2 +- .../transformations/lmgltransformation.py | 2 +- .../transformations/lmhptransformation.py | 2 +- .../transformations/rg_transforms.py | 0 .../transformations/rgrgtransformation.py | 4 +- .../transformations/transformation.py | 0 .../transformations/transformation_factory.py | 0 .../transformation_operator/__init__.py | 1 - test/test_nifty_transforms.py | 14 +++---- 16 files changed, 68 insertions(+), 44 deletions(-) create mode 100644 nifty/operators/fft_operator/__init__.py rename nifty/operators/{transformation_operator/transformation_operator.py => fft_operator/fft_operator.py} (78%) rename nifty/{ => operators/fft_operator}/transformations/__init__.py (80%) rename nifty/{ => operators/fft_operator}/transformations/gllmtransformation.py (98%) rename nifty/{ => operators/fft_operator}/transformations/hplmtransformation.py (98%) rename nifty/{ => operators/fft_operator}/transformations/lmgltransformation.py (98%) rename nifty/{ => operators/fft_operator}/transformations/lmhptransformation.py (98%) rename nifty/{ => operators/fft_operator}/transformations/rg_transforms.py (100%) rename nifty/{ => operators/fft_operator}/transformations/rgrgtransformation.py (98%) rename nifty/{ => operators/fft_operator}/transformations/transformation.py (100%) rename nifty/{ => operators/fft_operator}/transformations/transformation_factory.py (100%) delete mode 100644 nifty/operators/transformation_operator/__init__.py diff --git a/nifty/__init__.py b/nifty/__init__.py index fbfb90ca7..06e1af705 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -54,8 +54,6 @@ from field_types import FieldType,\ from spaces import * -from transformations import * - from operators import * from demos import get_demo_dir diff --git a/nifty/nifty_utilities.py b/nifty/nifty_utilities.py index 2bafae965..6da736a2a 100644 --- a/nifty/nifty_utilities.py +++ b/nifty/nifty_utilities.py @@ -41,14 +41,14 @@ def get_slice_list(shape, axes): axes(axis) does not match shape.") ) axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)] - axes_iterables =\ + axes_iterables = \ [range(y) for x, y in enumerate(shape) if x not in axes] for index in product(*axes_iterables): it_iter = iter(index) slice_list = [ next(it_iter) if axis else slice(None, None) for axis in axes_select - ] + ] yield slice_list else: yield [slice(None, None)] @@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None): # The fixed points of the point inversion must not be avaraged. # Hence one must multiply them again with sqrt(0.5) # -> Get the middle index of the array - mid_index = np.array(x.shape, dtype=np.int)//2 + mid_index = np.array(x.shape, dtype=np.int) // 2 dimensions = mid_index.size # Use ndindex to iterate over all combinations of zeros and the # mid_index in order to correct all fixed points. @@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None): ndlist = [2 if i in axes else 1 for i in xrange(dimensions)] ndlist = tuple(ndlist) for i in np.ndindex(ndlist): - temp_index = tuple(i*mid_index) + temp_index = tuple(i * mid_index) x[temp_index] *= np.sqrt(0.5) try: x.hermitian = True @@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes): # calculate the number of dimensions the input array has dimensions = len(x.shape) # prepare the slicing object which will be used for mirroring - slice_primitive = [slice(None), ]*dimensions + slice_primitive = [slice(None), ] * dimensions # copy the input data y = x.copy() @@ -208,8 +208,9 @@ def field_map(ishape, function, *args): # with ishape (3,4,3) and (3,4,1) def get_clipped(w, ind): w_shape = np.array(np.shape(w)) - get_tuple = tuple(np.clip(ind, 0, w_shape-1)) + get_tuple = tuple(np.clip(ind, 0, w_shape - 1)) return w[get_tuple] + result = np.empty_like(args[0]) for i in xrange(reduce(lambda x, y: x * y, result.shape)): ii = np.unravel_index(i, result.shape) @@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length): axis = tuple(int(item) for item in axis) except(TypeError): if np.isscalar(axis): - axis = (int(axis), ) + axis = (int(axis),) else: raise TypeError( "ERROR: Could not convert axis-input to tuple of ints") @@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length): # assert that all entries are elements in [0, length] for elem in axis: - assert(0 <= elem < length) + assert (0 <= elem < length) return axis @@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None): return real_bincount + imag_bincount else: return x.bincount(weights=weights, minlength=minlength) + + +def get_default_codomain(domain): + from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace + from nifty.operators.fft_operator.transformations import RGRGTransformation, \ + HPLMTransformation, GLLMTransformation, LMGLTransformation + + if isinstance(domain, RGSpace): + return RGRGTransformation.get_codomain(domain) + elif isinstance(domain, HPSpace): + return HPLMTransformation.get_codomain(domain) + elif isinstance(domain, GLSpace): + return GLLMTransformation.get_codomain(domain) + elif isinstance(domain, LMSpace): + # TODO: get the preferred transformation path from config + return LMGLTransformation.get_codomain(domain) + else: + raise TypeError(about._errors.cstring('ERROR: unknown domain')) diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index dfd5359cd..8e947122b 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -25,7 +25,7 @@ from linear_operator import LinearOperator from endomorphic_operator import EndomorphicOperator -from transformation_operator import TransformationOperator +from fft_operator import * from nifty_operators import operator,\ diagonal_operator,\ diff --git a/nifty/operators/fft_operator/__init__.py b/nifty/operators/fft_operator/__init__.py new file mode 100644 index 000000000..14eb7b8f9 --- /dev/null +++ b/nifty/operators/fft_operator/__init__.py @@ -0,0 +1,2 @@ +from transformations import * +from fft_operator import FFTOperator \ No newline at end of file diff --git a/nifty/operators/transformation_operator/transformation_operator.py b/nifty/operators/fft_operator/fft_operator.py similarity index 78% rename from nifty/operators/transformation_operator/transformation_operator.py rename to nifty/operators/fft_operator/fft_operator.py index 168d36ff2..01278de67 100644 --- a/nifty/operators/transformation_operator/transformation_operator.py +++ b/nifty/operators/fft_operator/fft_operator.py @@ -1,15 +1,18 @@ from nifty.config import about import nifty.nifty_utilities as utilities from nifty.operators.linear_operator import LinearOperator -from nifty.transformations import TransformationFactory +from transformations import TransformationFactory -class TransformationOperator(LinearOperator): +class FFTOperator(LinearOperator): + + # ---Overwritten properties and methods--- + def __init__(self, domain=(), field_type=(), target=(), field_type_target=(), implemented=True): - super(TransformationOperator, self).__init__(domain=domain, - field_type=field_type, - implemented=implemented) + super(FFTOperator, self).__init__(domain=domain, + field_type=field_type, + implemented=implemented) if self.domain == (): raise TypeError(about._errors.cstring( @@ -19,8 +22,8 @@ class TransformationOperator(LinearOperator): else: if len(self.domain) > 1: raise TypeError(about._errors.cstring( - 'ERROR: TransformationOperator accepts only a single' - 'space as input' + 'ERROR: TransformationOperator accepts only a single ' + 'space as input domain.' )) if self.field_type != (): @@ -30,7 +33,9 @@ class TransformationOperator(LinearOperator): )) # currently not sanitizing the target - self._target = self._parse_domain(target) + self._target = self._parse_domain( + utilities.get_default_codomain(self.domain[0]) + ) self._field_type_target = self._parse_field_type(field_type_target) if self.field_type_target != (): @@ -47,14 +52,6 @@ class TransformationOperator(LinearOperator): self.target[0], self.domain[0] ) - @property - def target(self): - return self._target - - @property - def field_type_target(self): - return self._field_type_target - def _times(self, x, spaces, types): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) @@ -64,3 +61,14 @@ class TransformationOperator(LinearOperator): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) return self._inverse_transformation.transform(x.val, axes=spaces) + + # ---Mandatory properties and methods--- + + @property + def target(self): + return self._target + + @property + def field_type_target(self): + return self._field_type_target + diff --git a/nifty/transformations/__init__.py b/nifty/operators/fft_operator/transformations/__init__.py similarity index 80% rename from nifty/transformations/__init__.py rename to nifty/operators/fft_operator/transformations/__init__.py index 35d89b2cb..42e6101b0 100644 --- a/nifty/transformations/__init__.py +++ b/nifty/operators/fft_operator/transformations/__init__.py @@ -1,9 +1,7 @@ - from rgrgtransformation import RGRGTransformation from gllmtransformation import GLLMTransformation from hplmtransformation import HPLMTransformation from lmgltransformation import LMGLTransformation from lmhptransformation import LMHPTransformation -from transformation_factory import TransformationFactory - +from transformation_factory import TransformationFactory \ No newline at end of file diff --git a/nifty/transformations/gllmtransformation.py b/nifty/operators/fft_operator/transformations/gllmtransformation.py similarity index 98% rename from nifty/transformations/gllmtransformation.py rename to nifty/operators/fft_operator/transformations/gllmtransformation.py index d8a646cd4..73abac3b1 100644 --- a/nifty/transformations/gllmtransformation.py +++ b/nifty/operators/fft_operator/transformations/gllmtransformation.py @@ -87,7 +87,7 @@ class GLLMTransformation(Transformation): """ if self.domain.discrete: - val = self.domain.calc_weight(val, power=-0.5) + val = self.domain.weight(val, power=-0.5, axes=axes) # shorthands for transform parameters nlat = self.domain.paradict['nlat'] diff --git a/nifty/transformations/hplmtransformation.py b/nifty/operators/fft_operator/transformations/hplmtransformation.py similarity index 98% rename from nifty/transformations/hplmtransformation.py rename to nifty/operators/fft_operator/transformations/hplmtransformation.py index 1de37b676..d1da272ef 100644 --- a/nifty/transformations/hplmtransformation.py +++ b/nifty/operators/fft_operator/transformations/hplmtransformation.py @@ -85,7 +85,7 @@ class HPLMTransformation(Transformation): niter = kwargs['niter'] if 'niter' in kwargs else 0 if self.domain.discrete: - val = self.domain.calc_weight(val, power=-0.5) + val = self.domain.weight(val, power=-0.5, axes=axes) # shorthands for transform parameters lmax = self.codomain.paradict['lmax'] diff --git a/nifty/transformations/lmgltransformation.py b/nifty/operators/fft_operator/transformations/lmgltransformation.py similarity index 98% rename from nifty/transformations/lmgltransformation.py rename to nifty/operators/fft_operator/transformations/lmgltransformation.py index c188fbe02..5d54c9bf4 100644 --- a/nifty/transformations/lmgltransformation.py +++ b/nifty/operators/fft_operator/transformations/lmgltransformation.py @@ -131,7 +131,7 @@ class LMGLTransformation(Transformation): # re-weight if discrete if self.codomain.discrete: - val = self.codomain.calc_weight(val, power=0.5) + val = self.codomain.weight(val, power=0.5, axes=axes) if isinstance(val, distributed_data_object): new_val = val.copy_empty(dtype=self.codomain.dtype) diff --git a/nifty/transformations/lmhptransformation.py b/nifty/operators/fft_operator/transformations/lmhptransformation.py similarity index 98% rename from nifty/transformations/lmhptransformation.py rename to nifty/operators/fft_operator/transformations/lmhptransformation.py index c9f00bab5..f2692ba7b 100644 --- a/nifty/transformations/lmhptransformation.py +++ b/nifty/operators/fft_operator/transformations/lmhptransformation.py @@ -114,7 +114,7 @@ class LMHPTransformation(Transformation): # re-weight if discrete if self.codomain.discrete: - val = self.codomain.calc_weight(val, power=0.5) + val = self.codomain.weight(val, power=0.5, axes=axes) if isinstance(val, distributed_data_object): new_val = val.copy_empty(dtype=self.codomain.dtype) diff --git a/nifty/transformations/rg_transforms.py b/nifty/operators/fft_operator/transformations/rg_transforms.py similarity index 100% rename from nifty/transformations/rg_transforms.py rename to nifty/operators/fft_operator/transformations/rg_transforms.py diff --git a/nifty/transformations/rgrgtransformation.py b/nifty/operators/fft_operator/transformations/rgrgtransformation.py similarity index 98% rename from nifty/transformations/rgrgtransformation.py rename to nifty/operators/fft_operator/transformations/rgrgtransformation.py index b643eda0f..758d08d15 100644 --- a/nifty/transformations/rgrgtransformation.py +++ b/nifty/operators/fft_operator/transformations/rgrgtransformation.py @@ -126,13 +126,13 @@ class RGRGTransformation(Transformation): """ if self._transform.codomain.harmonic: # correct for forward fft - val = self._transform.domain.weight(val, power=1) + val = self._transform.domain.weight(val, power=1, axes=axes) # Perform the transformation Tval = self._transform.transform(val, axes, **kwargs) if not self._transform.codomain.harmonic: # correct for inverse fft - Tval = self._transform.codomain.weight(Tval, power=-1) + Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes) return Tval diff --git a/nifty/transformations/transformation.py b/nifty/operators/fft_operator/transformations/transformation.py similarity index 100% rename from nifty/transformations/transformation.py rename to nifty/operators/fft_operator/transformations/transformation.py diff --git a/nifty/transformations/transformation_factory.py b/nifty/operators/fft_operator/transformations/transformation_factory.py similarity index 100% rename from nifty/transformations/transformation_factory.py rename to nifty/operators/fft_operator/transformations/transformation_factory.py diff --git a/nifty/operators/transformation_operator/__init__.py b/nifty/operators/transformation_operator/__init__.py deleted file mode 100644 index 474fd83d0..000000000 --- a/nifty/operators/transformation_operator/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from transformation_operator import TransformationOperator \ No newline at end of file diff --git a/test/test_nifty_transforms.py b/test/test_nifty_transforms.py index aa3582983..5579442e2 100644 --- a/test/test_nifty_transforms.py +++ b/test/test_nifty_transforms.py @@ -1,15 +1,15 @@ -import numpy as np -from numpy.testing import assert_equal, assert_almost_equal, assert_raises +import itertools +import unittest +import d2o +import numpy as np +from nifty.rg.rg_space import gc as RG_GC from nose_parameterized import parameterized -import unittest -import itertools +from numpy.testing import assert_raises from nifty import RGSpace, LMSpace, HPSpace, GLSpace from nifty import transformator -from nifty.transformations.rgrgtransformation import RGRGTransformation -from nifty.rg.rg_space import gc as RG_GC -import d2o +from nifty.operators.fft_operator.transformations.rgrgtransformation import RGRGTransformation ############################################################################### -- GitLab