Commit 162f4066 authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'transformation_operator' into 'feature/field_multiple_space'

TransformationOperator



See merge request !23
parents d57b2f02 3265ea51
...@@ -52,11 +52,10 @@ from nifty_utilities import * ...@@ -52,11 +52,10 @@ from nifty_utilities import *
from field_types import FieldType,\ from field_types import FieldType,\
FieldArray FieldArray
from operators import *
from spaces import * from spaces import *
from operators import *
from demos import get_demo_dir from demos import get_demo_dir
#import pyximport; pyximport.install(pyimport = True) #import pyximport; pyximport.install(pyimport = True)
from transformations import *
...@@ -40,6 +40,7 @@ class Field(object): ...@@ -40,6 +40,7 @@ class Field(object):
start=start) start=start)
self.dtype = self._infer_dtype(dtype=dtype, self.dtype = self._infer_dtype(dtype=dtype,
val=val,
domain=self.domain, domain=self.domain,
field_type=self.field_type) field_type=self.field_type)
......
...@@ -41,14 +41,14 @@ def get_slice_list(shape, axes): ...@@ -41,14 +41,14 @@ def get_slice_list(shape, axes):
axes(axis) does not match shape.") axes(axis) does not match shape.")
) )
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)] axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables =\ axes_iterables = \
[range(y) for x, y in enumerate(shape) if x not in axes] [range(y) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables): for index in product(*axes_iterables):
it_iter = iter(index) it_iter = iter(index)
slice_list = [ slice_list = [
next(it_iter) next(it_iter)
if axis else slice(None, None) for axis in axes_select if axis else slice(None, None) for axis in axes_select
] ]
yield slice_list yield slice_list
else: else:
yield [slice(None, None)] yield [slice(None, None)]
...@@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None): ...@@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None):
# The fixed points of the point inversion must not be avaraged. # The fixed points of the point inversion must not be avaraged.
# Hence one must multiply them again with sqrt(0.5) # Hence one must multiply them again with sqrt(0.5)
# -> Get the middle index of the array # -> Get the middle index of the array
mid_index = np.array(x.shape, dtype=np.int)//2 mid_index = np.array(x.shape, dtype=np.int) // 2
dimensions = mid_index.size dimensions = mid_index.size
# Use ndindex to iterate over all combinations of zeros and the # Use ndindex to iterate over all combinations of zeros and the
# mid_index in order to correct all fixed points. # mid_index in order to correct all fixed points.
...@@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None): ...@@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None):
ndlist = [2 if i in axes else 1 for i in xrange(dimensions)] ndlist = [2 if i in axes else 1 for i in xrange(dimensions)]
ndlist = tuple(ndlist) ndlist = tuple(ndlist)
for i in np.ndindex(ndlist): for i in np.ndindex(ndlist):
temp_index = tuple(i*mid_index) temp_index = tuple(i * mid_index)
x[temp_index] *= np.sqrt(0.5) x[temp_index] *= np.sqrt(0.5)
try: try:
x.hermitian = True x.hermitian = True
...@@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes): ...@@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes):
# calculate the number of dimensions the input array has # calculate the number of dimensions the input array has
dimensions = len(x.shape) dimensions = len(x.shape)
# prepare the slicing object which will be used for mirroring # prepare the slicing object which will be used for mirroring
slice_primitive = [slice(None), ]*dimensions slice_primitive = [slice(None), ] * dimensions
# copy the input data # copy the input data
y = x.copy() y = x.copy()
...@@ -208,8 +208,9 @@ def field_map(ishape, function, *args): ...@@ -208,8 +208,9 @@ def field_map(ishape, function, *args):
# with ishape (3,4,3) and (3,4,1) # with ishape (3,4,3) and (3,4,1)
def get_clipped(w, ind): def get_clipped(w, ind):
w_shape = np.array(np.shape(w)) w_shape = np.array(np.shape(w))
get_tuple = tuple(np.clip(ind, 0, w_shape-1)) get_tuple = tuple(np.clip(ind, 0, w_shape - 1))
return w[get_tuple] return w[get_tuple]
result = np.empty_like(args[0]) result = np.empty_like(args[0])
for i in xrange(reduce(lambda x, y: x * y, result.shape)): for i in xrange(reduce(lambda x, y: x * y, result.shape)):
ii = np.unravel_index(i, result.shape) ii = np.unravel_index(i, result.shape)
...@@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length): ...@@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length):
axis = tuple(int(item) for item in axis) axis = tuple(int(item) for item in axis)
except(TypeError): except(TypeError):
if np.isscalar(axis): if np.isscalar(axis):
axis = (int(axis), ) axis = (int(axis),)
else: else:
raise TypeError( raise TypeError(
"ERROR: Could not convert axis-input to tuple of ints") "ERROR: Could not convert axis-input to tuple of ints")
...@@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length): ...@@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length):
# assert that all entries are elements in [0, length] # assert that all entries are elements in [0, length]
for elem in axis: for elem in axis:
assert(0 <= elem < length) assert (0 <= elem < length)
return axis return axis
...@@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None): ...@@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None):
return real_bincount + imag_bincount return real_bincount + imag_bincount
else: else:
return x.bincount(weights=weights, minlength=minlength) return x.bincount(weights=weights, minlength=minlength)
def get_default_codomain(domain):
from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace
from nifty.operators.fft_operator.transformations import RGRGTransformation, \
HPLMTransformation, GLLMTransformation, LMGLTransformation
if isinstance(domain, RGSpace):
return RGRGTransformation.get_codomain(domain)
elif isinstance(domain, HPSpace):
return HPLMTransformation.get_codomain(domain)
elif isinstance(domain, GLSpace):
return GLLMTransformation.get_codomain(domain)
elif isinstance(domain, LMSpace):
# TODO: get the preferred transformation path from config
return LMGLTransformation.get_codomain(domain)
else:
raise TypeError(about._errors.cstring('ERROR: unknown domain'))
...@@ -25,6 +25,8 @@ from linear_operator import LinearOperator ...@@ -25,6 +25,8 @@ from linear_operator import LinearOperator
from endomorphic_operator import EndomorphicOperator from endomorphic_operator import EndomorphicOperator
from fft_operator import *
from nifty_operators import operator,\ from nifty_operators import operator,\
diagonal_operator,\ diagonal_operator,\
power_operator,\ power_operator,\
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from endmorphic_operator import EndomorphicOperator from endomorphic_operator import EndomorphicOperator
from transformations import *
from fft_operator import FFTOperator
\ No newline at end of file
from nifty.config import about
import nifty.nifty_utilities as utilities
from nifty.operators.linear_operator import LinearOperator
from transformations import TransformationFactory
class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=(),
field_type_target=(), implemented=True):
super(FFTOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
if self.domain == ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator needs a single space as '
'input domain.'
))
else:
if len(self.domain) > 1:
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if self.field_type != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator field-type has to be an '
'empty tuple.'
))
# currently not sanitizing the target
self._target = self._parse_domain(
utilities.get_default_codomain(self.domain[0])
)
self._field_type_target = self._parse_field_type(field_type_target)
if self.field_type_target != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self._forward_transformation = TransformationFactory.create(
self.domain[0], self.target[0]
)
self._inverse_transformation = TransformationFactory.create(
self.target[0], self.domain[0]
)
def adjoint_times(self, x, spaces=None, types=None):
return self.inverse_times(x, spaces, types)
def adjoint_inverse_times(self, x, spaces=None, types=None):
return self.times(x, spaces, types)
def inverse_adjoint_times(self, x, spaces=None, types=None):
return self.times(x, spaces, types)
def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._forward_transformation.transform(x.val, axes=spaces)
def _inverse_times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._inverse_transformation.transform(x.val, axes=spaces)
# ---Mandatory properties and methods---
@property
def target(self):
return self._target
@property
def field_type_target(self):
return self._field_type_target
from rgrgtransformation import RGRGTransformation from rgrgtransformation import RGRGTransformation
from gllmtransformation import GLLMTransformation from gllmtransformation import GLLMTransformation
from hplmtransformation import HPLMTransformation from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation from lmhptransformation import LMHPTransformation
from transformation_factory import TransformationFactory from transformation_factory import TransformationFactory
\ No newline at end of file
...@@ -87,7 +87,7 @@ class GLLMTransformation(Transformation): ...@@ -87,7 +87,7 @@ class GLLMTransformation(Transformation):
""" """
if self.domain.discrete: if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5) val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters # shorthands for transform parameters
nlat = self.domain.paradict['nlat'] nlat = self.domain.paradict['nlat']
......
...@@ -85,7 +85,7 @@ class HPLMTransformation(Transformation): ...@@ -85,7 +85,7 @@ class HPLMTransformation(Transformation):
niter = kwargs['niter'] if 'niter' in kwargs else 0 niter = kwargs['niter'] if 'niter' in kwargs else 0
if self.domain.discrete: if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5) val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters # shorthands for transform parameters
lmax = self.codomain.paradict['lmax'] lmax = self.codomain.paradict['lmax']
......
...@@ -131,7 +131,7 @@ class LMGLTransformation(Transformation): ...@@ -131,7 +131,7 @@ class LMGLTransformation(Transformation):
# re-weight if discrete # re-weight if discrete
if self.codomain.discrete: if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5) val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype) new_val = val.copy_empty(dtype=self.codomain.dtype)
......
...@@ -114,7 +114,7 @@ class LMHPTransformation(Transformation): ...@@ -114,7 +114,7 @@ class LMHPTransformation(Transformation):
# re-weight if discrete # re-weight if discrete
if self.codomain.discrete: if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5) val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype) new_val = val.copy_empty(dtype=self.codomain.dtype)
......
...@@ -359,8 +359,7 @@ class FFTW(Transform): ...@@ -359,8 +359,7 @@ class FFTW(Transform):
local_shape=val.local_shape, local_shape=val.local_shape,
local_offset_Q=local_offset_Q, local_offset_Q=local_offset_Q,
is_local=False, is_local=False,
transform_shape=val.shape, transform_shape=inp.shape,
# TODO: check why inp.shape doesn't work
**kwargs **kwargs
) )
...@@ -437,10 +436,6 @@ class FFTW(Transform): ...@@ -437,10 +436,6 @@ class FFTW(Transform):
val, axes, **kwargs val, axes, **kwargs
) )
# If domain is purely real, the result of the FFT is hermitian
if self.domain.paradict['complexity'] == 0:
return_val.hermitian = True
return return_val return return_val
...@@ -636,10 +631,6 @@ class GFFT(Transform): ...@@ -636,10 +631,6 @@ class GFFT(Transform):
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype) new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False) new_val.set_full_data(return_val, copy=False)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if self.domain.paradict['complexity'] == 0:
new_val.hermitian = True
return_val = new_val return_val = new_val
else: else:
return_val = return_val.astype(self.codomain.dtype, copy=False) return_val = return_val.astype(self.codomain.dtype, copy=False)
......
...@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation): ...@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation):
""" """
if self._transform.codomain.harmonic: if self._transform.codomain.harmonic:
# correct for forward fft # correct for forward fft
val = self._transform.domain.calc_weight(val, power=1) val = self._transform.domain.weight(val, power=1, axes=axes)
# Perform the transformation # Perform the transformation
Tval = self._transform.transform(val, axes, **kwargs) Tval = self._transform.transform(val, axes, **kwargs)
if not self._transform.codomain.harmonic: if not self._transform.codomain.harmonic:
# correct for inverse fft # correct for inverse fft
Tval = self._transform.codomain.calc_weight(Tval, power=-1) Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)
return Tval return Tval
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from linear_operator import LinearOperator from linear_operator import LinearOperator
from linear_operator_paradict import LinearOperatorParadict
...@@ -16,7 +16,7 @@ from nifty.config import about,\ ...@@ -16,7 +16,7 @@ from nifty.config import about,\
dependency_injector as gdi dependency_injector as gdi
from lm_space_paradict import LMSpaceParadict from lm_space_paradict import LMSpaceParadict
from nifty.nifty_power_indices import lm_power_indices # from nifty.nifty_power_indices import lm_power_indices
from nifty.nifty_random import random from nifty.nifty_random import random
gl = gdi.get('libsharp_wrapper_gl') gl = gdi.get('libsharp_wrapper_gl')
......
import numpy as np import itertools
from numpy.testing import assert_equal, assert_almost_equal, assert_raises import unittest
import d2o
import numpy as np
from nifty.rg.rg_space import gc as RG_GC
from nose_parameterized import parameterized from nose_parameterized import parameterized
import unittest from numpy.testing import assert_raises
import itertools
from nifty import RGSpace, LMSpace, HPSpace, GLSpace from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator from nifty import transformator
from nifty.transformations.rgrgtransformation import RGRGTransformation from nifty.operators.fft_operator.transformations.rgrgtransformation import RGRGTransformation
from nifty.rg.rg_space import gc as RG_GC
import d2o
############################################################################### ###############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment