Commit 39530b9d authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: FFTOperator

- Add get_default_codomain function in utilities
- Move all transforms to a subfolder inside the fft_operator
- Correct weight method calls in transform method of each transformation
parent cd1951d6
......@@ -54,8 +54,6 @@ from field_types import FieldType,\
from spaces import *
from transformations import *
from operators import *
from demos import get_demo_dir
......
......@@ -41,14 +41,14 @@ def get_slice_list(shape, axes):
axes(axis) does not match shape.")
)
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables =\
axes_iterables = \
[range(y) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
]
yield slice_list
else:
yield [slice(None, None)]
......@@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None):
# The fixed points of the point inversion must not be avaraged.
# Hence one must multiply them again with sqrt(0.5)
# -> Get the middle index of the array
mid_index = np.array(x.shape, dtype=np.int)//2
mid_index = np.array(x.shape, dtype=np.int) // 2
dimensions = mid_index.size
# Use ndindex to iterate over all combinations of zeros and the
# mid_index in order to correct all fixed points.
......@@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None):
ndlist = [2 if i in axes else 1 for i in xrange(dimensions)]
ndlist = tuple(ndlist)
for i in np.ndindex(ndlist):
temp_index = tuple(i*mid_index)
temp_index = tuple(i * mid_index)
x[temp_index] *= np.sqrt(0.5)
try:
x.hermitian = True
......@@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes):
# calculate the number of dimensions the input array has
dimensions = len(x.shape)
# prepare the slicing object which will be used for mirroring
slice_primitive = [slice(None), ]*dimensions
slice_primitive = [slice(None), ] * dimensions
# copy the input data
y = x.copy()
......@@ -208,8 +208,9 @@ def field_map(ishape, function, *args):
# with ishape (3,4,3) and (3,4,1)
def get_clipped(w, ind):
w_shape = np.array(np.shape(w))
get_tuple = tuple(np.clip(ind, 0, w_shape-1))
get_tuple = tuple(np.clip(ind, 0, w_shape - 1))
return w[get_tuple]
result = np.empty_like(args[0])
for i in xrange(reduce(lambda x, y: x * y, result.shape)):
ii = np.unravel_index(i, result.shape)
......@@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length):
axis = tuple(int(item) for item in axis)
except(TypeError):
if np.isscalar(axis):
axis = (int(axis), )
axis = (int(axis),)
else:
raise TypeError(
"ERROR: Could not convert axis-input to tuple of ints")
......@@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length):
# assert that all entries are elements in [0, length]
for elem in axis:
assert(0 <= elem < length)
assert (0 <= elem < length)
return axis
......@@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None):
return real_bincount + imag_bincount
else:
return x.bincount(weights=weights, minlength=minlength)
def get_default_codomain(domain):
from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace
from nifty.operators.fft_operator.transformations import RGRGTransformation, \
HPLMTransformation, GLLMTransformation, LMGLTransformation
if isinstance(domain, RGSpace):
return RGRGTransformation.get_codomain(domain)
elif isinstance(domain, HPSpace):
return HPLMTransformation.get_codomain(domain)
elif isinstance(domain, GLSpace):
return GLLMTransformation.get_codomain(domain)
elif isinstance(domain, LMSpace):
# TODO: get the preferred transformation path from config
return LMGLTransformation.get_codomain(domain)
else:
raise TypeError(about._errors.cstring('ERROR: unknown domain'))
......@@ -25,7 +25,7 @@ from linear_operator import LinearOperator
from endomorphic_operator import EndomorphicOperator
from transformation_operator import TransformationOperator
from fft_operator import *
from nifty_operators import operator,\
diagonal_operator,\
......
from transformations import *
from fft_operator import FFTOperator
\ No newline at end of file
from nifty.config import about
import nifty.nifty_utilities as utilities
from nifty.operators.linear_operator import LinearOperator
from nifty.transformations import TransformationFactory
from transformations import TransformationFactory
class TransformationOperator(LinearOperator):
class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=(),
field_type_target=(), implemented=True):
super(TransformationOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
super(FFTOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
if self.domain == ():
raise TypeError(about._errors.cstring(
......@@ -19,8 +22,8 @@ class TransformationOperator(LinearOperator):
else:
if len(self.domain) > 1:
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator accepts only a single'
'space as input'
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if self.field_type != ():
......@@ -30,7 +33,9 @@ class TransformationOperator(LinearOperator):
))
# currently not sanitizing the target
self._target = self._parse_domain(target)
self._target = self._parse_domain(
utilities.get_default_codomain(self.domain[0])
)
self._field_type_target = self._parse_field_type(field_type_target)
if self.field_type_target != ():
......@@ -47,14 +52,6 @@ class TransformationOperator(LinearOperator):
self.target[0], self.domain[0]
)
@property
def target(self):
return self._target
@property
def field_type_target(self):
return self._field_type_target
def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
......@@ -64,3 +61,14 @@ class TransformationOperator(LinearOperator):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._inverse_transformation.transform(x.val, axes=spaces)
# ---Mandatory properties and methods---
@property
def target(self):
return self._target
@property
def field_type_target(self):
return self._field_type_target
from rgrgtransformation import RGRGTransformation
from gllmtransformation import GLLMTransformation
from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation
from transformation_factory import TransformationFactory
from transformation_factory import TransformationFactory
\ No newline at end of file
......@@ -87,7 +87,7 @@ class GLLMTransformation(Transformation):
"""
if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5)
val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters
nlat = self.domain.paradict['nlat']
......
......@@ -85,7 +85,7 @@ class HPLMTransformation(Transformation):
niter = kwargs['niter'] if 'niter' in kwargs else 0
if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5)
val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters
lmax = self.codomain.paradict['lmax']
......
......@@ -131,7 +131,7 @@ class LMGLTransformation(Transformation):
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5)
val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
......
......@@ -114,7 +114,7 @@ class LMHPTransformation(Transformation):
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5)
val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
......
......@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation):
"""
if self._transform.codomain.harmonic:
# correct for forward fft
val = self._transform.domain.weight(val, power=1)
val = self._transform.domain.weight(val, power=1, axes=axes)
# Perform the transformation
Tval = self._transform.transform(val, axes, **kwargs)
if not self._transform.codomain.harmonic:
# correct for inverse fft
Tval = self._transform.codomain.weight(Tval, power=-1)
Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)
return Tval
from transformation_operator import TransformationOperator
\ No newline at end of file
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal, assert_raises
import itertools
import unittest
import d2o
import numpy as np
from nifty.rg.rg_space import gc as RG_GC
from nose_parameterized import parameterized
import unittest
import itertools
from numpy.testing import assert_raises
from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator
from nifty.transformations.rgrgtransformation import RGRGTransformation
from nifty.rg.rg_space import gc as RG_GC
import d2o
from nifty.operators.fft_operator.transformations.rgrgtransformation import RGRGTransformation
###############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment