Commit 39530b9d by Jait Dixit

### WIP: FFTOperator

```- Add get_default_codomain function in utilities
- Move all transforms to a subfolder inside the fft_operator
- Correct weight method calls in transform method of each transformation```
parent cd1951d6
 ... ... @@ -54,8 +54,6 @@ from field_types import FieldType,\ from spaces import * from transformations import * from operators import * from demos import get_demo_dir ... ...
 ... ... @@ -41,14 +41,14 @@ def get_slice_list(shape, axes): axes(axis) does not match shape.") ) axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)] axes_iterables =\ axes_iterables = \ [range(y) for x, y in enumerate(shape) if x not in axes] for index in product(*axes_iterables): it_iter = iter(index) slice_list = [ next(it_iter) if axis else slice(None, None) for axis in axes_select ] ] yield slice_list else: yield [slice(None, None)] ... ... @@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None): # The fixed points of the point inversion must not be avaraged. # Hence one must multiply them again with sqrt(0.5) # -> Get the middle index of the array mid_index = np.array(x.shape, dtype=np.int)//2 mid_index = np.array(x.shape, dtype=np.int) // 2 dimensions = mid_index.size # Use ndindex to iterate over all combinations of zeros and the # mid_index in order to correct all fixed points. ... ... @@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None): ndlist = [2 if i in axes else 1 for i in xrange(dimensions)] ndlist = tuple(ndlist) for i in np.ndindex(ndlist): temp_index = tuple(i*mid_index) temp_index = tuple(i * mid_index) x[temp_index] *= np.sqrt(0.5) try: x.hermitian = True ... ... @@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes): # calculate the number of dimensions the input array has dimensions = len(x.shape) # prepare the slicing object which will be used for mirroring slice_primitive = [slice(None), ]*dimensions slice_primitive = [slice(None), ] * dimensions # copy the input data y = x.copy() ... ... @@ -208,8 +208,9 @@ def field_map(ishape, function, *args): # with ishape (3,4,3) and (3,4,1) def get_clipped(w, ind): w_shape = np.array(np.shape(w)) get_tuple = tuple(np.clip(ind, 0, w_shape-1)) get_tuple = tuple(np.clip(ind, 0, w_shape - 1)) return w[get_tuple] result = np.empty_like(args[0]) for i in xrange(reduce(lambda x, y: x * y, result.shape)): ii = np.unravel_index(i, result.shape) ... ... @@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length): axis = tuple(int(item) for item in axis) except(TypeError): if np.isscalar(axis): axis = (int(axis), ) axis = (int(axis),) else: raise TypeError( "ERROR: Could not convert axis-input to tuple of ints") ... ... @@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length): # assert that all entries are elements in [0, length] for elem in axis: assert(0 <= elem < length) assert (0 <= elem < length) return axis ... ... @@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None): return real_bincount + imag_bincount else: return x.bincount(weights=weights, minlength=minlength) def get_default_codomain(domain): from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace from nifty.operators.fft_operator.transformations import RGRGTransformation, \ HPLMTransformation, GLLMTransformation, LMGLTransformation if isinstance(domain, RGSpace): return RGRGTransformation.get_codomain(domain) elif isinstance(domain, HPSpace): return HPLMTransformation.get_codomain(domain) elif isinstance(domain, GLSpace): return GLLMTransformation.get_codomain(domain) elif isinstance(domain, LMSpace): # TODO: get the preferred transformation path from config return LMGLTransformation.get_codomain(domain) else: raise TypeError(about._errors.cstring('ERROR: unknown domain'))
 ... ... @@ -25,7 +25,7 @@ from linear_operator import LinearOperator from endomorphic_operator import EndomorphicOperator from transformation_operator import TransformationOperator from fft_operator import * from nifty_operators import operator,\ diagonal_operator,\ ... ...
 from transformations import * from fft_operator import FFTOperator \ No newline at end of file
 from nifty.config import about import nifty.nifty_utilities as utilities from nifty.operators.linear_operator import LinearOperator from nifty.transformations import TransformationFactory from transformations import TransformationFactory class TransformationOperator(LinearOperator): class FFTOperator(LinearOperator): # ---Overwritten properties and methods--- def __init__(self, domain=(), field_type=(), target=(), field_type_target=(), implemented=True): super(TransformationOperator, self).__init__(domain=domain, field_type=field_type, implemented=implemented) super(FFTOperator, self).__init__(domain=domain, field_type=field_type, implemented=implemented) if self.domain == (): raise TypeError(about._errors.cstring( ... ... @@ -19,8 +22,8 @@ class TransformationOperator(LinearOperator): else: if len(self.domain) > 1: raise TypeError(about._errors.cstring( 'ERROR: TransformationOperator accepts only a single' 'space as input' 'ERROR: TransformationOperator accepts only a single ' 'space as input domain.' )) if self.field_type != (): ... ... @@ -30,7 +33,9 @@ class TransformationOperator(LinearOperator): )) # currently not sanitizing the target self._target = self._parse_domain(target) self._target = self._parse_domain( utilities.get_default_codomain(self.domain[0]) ) self._field_type_target = self._parse_field_type(field_type_target) if self.field_type_target != (): ... ... @@ -47,14 +52,6 @@ class TransformationOperator(LinearOperator): self.target[0], self.domain[0] ) @property def target(self): return self._target @property def field_type_target(self): return self._field_type_target def _times(self, x, spaces, types): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) ... ... @@ -64,3 +61,14 @@ class TransformationOperator(LinearOperator): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) return self._inverse_transformation.transform(x.val, axes=spaces) # ---Mandatory properties and methods--- @property def target(self): return self._target @property def field_type_target(self): return self._field_type_target
 from rgrgtransformation import RGRGTransformation from gllmtransformation import GLLMTransformation from hplmtransformation import HPLMTransformation from lmgltransformation import LMGLTransformation from lmhptransformation import LMHPTransformation from transformation_factory import TransformationFactory from transformation_factory import TransformationFactory \ No newline at end of file
 ... ... @@ -87,7 +87,7 @@ class GLLMTransformation(Transformation): """ if self.domain.discrete: val = self.domain.calc_weight(val, power=-0.5) val = self.domain.weight(val, power=-0.5, axes=axes) # shorthands for transform parameters nlat = self.domain.paradict['nlat'] ... ...
 ... ... @@ -85,7 +85,7 @@ class HPLMTransformation(Transformation): niter = kwargs['niter'] if 'niter' in kwargs else 0 if self.domain.discrete: val = self.domain.calc_weight(val, power=-0.5) val = self.domain.weight(val, power=-0.5, axes=axes) # shorthands for transform parameters lmax = self.codomain.paradict['lmax'] ... ...
 ... ... @@ -131,7 +131,7 @@ class LMGLTransformation(Transformation): # re-weight if discrete if self.codomain.discrete: val = self.codomain.calc_weight(val, power=0.5) val = self.codomain.weight(val, power=0.5, axes=axes) if isinstance(val, distributed_data_object): new_val = val.copy_empty(dtype=self.codomain.dtype) ... ...
 ... ... @@ -114,7 +114,7 @@ class LMHPTransformation(Transformation): # re-weight if discrete if self.codomain.discrete: val = self.codomain.calc_weight(val, power=0.5) val = self.codomain.weight(val, power=0.5, axes=axes) if isinstance(val, distributed_data_object): new_val = val.copy_empty(dtype=self.codomain.dtype) ... ...
 ... ... @@ -126,13 +126,13 @@ class RGRGTransformation(Transformation): """ if self._transform.codomain.harmonic: # correct for forward fft val = self._transform.domain.weight(val, power=1) val = self._transform.domain.weight(val, power=1, axes=axes) # Perform the transformation Tval = self._transform.transform(val, axes, **kwargs) if not self._transform.codomain.harmonic: # correct for inverse fft Tval = self._transform.codomain.weight(Tval, power=-1) Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes) return Tval
 from transformation_operator import TransformationOperator \ No newline at end of file
 import numpy as np from numpy.testing import assert_equal, assert_almost_equal, assert_raises import itertools import unittest import d2o import numpy as np from nifty.rg.rg_space import gc as RG_GC from nose_parameterized import parameterized import unittest import itertools from numpy.testing import assert_raises from nifty import RGSpace, LMSpace, HPSpace, GLSpace from nifty import transformator from nifty.transformations.rgrgtransformation import RGRGTransformation from nifty.rg.rg_space import gc as RG_GC import d2o from nifty.operators.fft_operator.transformations.rgrgtransformation import RGRGTransformation ############################################################################### ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!