Commit 28b305d3 authored by theos's avatar theos
Browse files

Merge branch 'feature/field_multiple_space' of gitlab.mpcdf.mpg.de:ift/NIFTy...

Merge branch 'feature/field_multiple_space' of gitlab.mpcdf.mpg.de:ift/NIFTy into feature/field_multiple_space
parents f79918d5 162f4066
......@@ -52,11 +52,10 @@ from nifty_utilities import *
from field_types import FieldType,\
FieldArray
from operators import *
from spaces import *
from operators import *
from demos import get_demo_dir
#import pyximport; pyximport.install(pyimport = True)
from transformations import *
......@@ -40,6 +40,7 @@ class Field(object):
start=start)
self.dtype = self._infer_dtype(dtype=dtype,
val=val,
domain=self.domain,
field_type=self.field_type)
......
......@@ -41,14 +41,14 @@ def get_slice_list(shape, axes):
axes(axis) does not match shape.")
)
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables =\
axes_iterables = \
[range(y) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
]
yield slice_list
else:
yield [slice(None, None)]
......@@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None):
# The fixed points of the point inversion must not be avaraged.
# Hence one must multiply them again with sqrt(0.5)
# -> Get the middle index of the array
mid_index = np.array(x.shape, dtype=np.int)//2
mid_index = np.array(x.shape, dtype=np.int) // 2
dimensions = mid_index.size
# Use ndindex to iterate over all combinations of zeros and the
# mid_index in order to correct all fixed points.
......@@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None):
ndlist = [2 if i in axes else 1 for i in xrange(dimensions)]
ndlist = tuple(ndlist)
for i in np.ndindex(ndlist):
temp_index = tuple(i*mid_index)
temp_index = tuple(i * mid_index)
x[temp_index] *= np.sqrt(0.5)
try:
x.hermitian = True
......@@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes):
# calculate the number of dimensions the input array has
dimensions = len(x.shape)
# prepare the slicing object which will be used for mirroring
slice_primitive = [slice(None), ]*dimensions
slice_primitive = [slice(None), ] * dimensions
# copy the input data
y = x.copy()
......@@ -208,8 +208,9 @@ def field_map(ishape, function, *args):
# with ishape (3,4,3) and (3,4,1)
def get_clipped(w, ind):
w_shape = np.array(np.shape(w))
get_tuple = tuple(np.clip(ind, 0, w_shape-1))
get_tuple = tuple(np.clip(ind, 0, w_shape - 1))
return w[get_tuple]
result = np.empty_like(args[0])
for i in xrange(reduce(lambda x, y: x * y, result.shape)):
ii = np.unravel_index(i, result.shape)
......@@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length):
axis = tuple(int(item) for item in axis)
except(TypeError):
if np.isscalar(axis):
axis = (int(axis), )
axis = (int(axis),)
else:
raise TypeError(
"ERROR: Could not convert axis-input to tuple of ints")
......@@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length):
# assert that all entries are elements in [0, length]
for elem in axis:
assert(0 <= elem < length)
assert (0 <= elem < length)
return axis
......@@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None):
return real_bincount + imag_bincount
else:
return x.bincount(weights=weights, minlength=minlength)
def get_default_codomain(domain):
from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace
from nifty.operators.fft_operator.transformations import RGRGTransformation, \
HPLMTransformation, GLLMTransformation, LMGLTransformation
if isinstance(domain, RGSpace):
return RGRGTransformation.get_codomain(domain)
elif isinstance(domain, HPSpace):
return HPLMTransformation.get_codomain(domain)
elif isinstance(domain, GLSpace):
return GLLMTransformation.get_codomain(domain)
elif isinstance(domain, LMSpace):
# TODO: get the preferred transformation path from config
return LMGLTransformation.get_codomain(domain)
else:
raise TypeError(about._errors.cstring('ERROR: unknown domain'))
......@@ -25,6 +25,8 @@ from linear_operator import LinearOperator
from endomorphic_operator import EndomorphicOperator
from fft_operator import *
from nifty_operators import operator,\
diagonal_operator,\
power_operator,\
......
from transformations import *
from fft_operator import FFTOperator
\ No newline at end of file
from nifty.config import about
import nifty.nifty_utilities as utilities
from nifty.operators.linear_operator import LinearOperator
from transformations import TransformationFactory
class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=(),
field_type_target=(), implemented=True):
super(FFTOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
if self.domain == ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator needs a single space as '
'input domain.'
))
else:
if len(self.domain) > 1:
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if self.field_type != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator field-type has to be an '
'empty tuple.'
))
# currently not sanitizing the target
self._target = self._parse_domain(
utilities.get_default_codomain(self.domain[0])
)
self._field_type_target = self._parse_field_type(field_type_target)
if self.field_type_target != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self._forward_transformation = TransformationFactory.create(
self.domain[0], self.target[0]
)
self._inverse_transformation = TransformationFactory.create(
self.target[0], self.domain[0]
)
def adjoint_times(self, x, spaces=None, types=None):
return self.inverse_times(x, spaces, types)
def adjoint_inverse_times(self, x, spaces=None, types=None):
return self.times(x, spaces, types)
def inverse_adjoint_times(self, x, spaces=None, types=None):
return self.times(x, spaces, types)
def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._forward_transformation.transform(x.val, axes=spaces)
def _inverse_times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._inverse_transformation.transform(x.val, axes=spaces)
# ---Mandatory properties and methods---
@property
def target(self):
return self._target
@property
def field_type_target(self):
return self._field_type_target
from rgrgtransformation import RGRGTransformation
from gllmtransformation import GLLMTransformation
from hplmtransformation import HPLMTransformation
from lmgltransformation import LMGLTransformation
from lmhptransformation import LMHPTransformation
from transformation_factory import TransformationFactory
from transformation_factory import TransformationFactory
\ No newline at end of file
......@@ -87,7 +87,7 @@ class GLLMTransformation(Transformation):
"""
if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5)
val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters
nlat = self.domain.paradict['nlat']
......
......@@ -85,7 +85,7 @@ class HPLMTransformation(Transformation):
niter = kwargs['niter'] if 'niter' in kwargs else 0
if self.domain.discrete:
val = self.domain.calc_weight(val, power=-0.5)
val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters
lmax = self.codomain.paradict['lmax']
......
......@@ -131,7 +131,7 @@ class LMGLTransformation(Transformation):
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5)
val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
......
......@@ -114,7 +114,7 @@ class LMHPTransformation(Transformation):
# re-weight if discrete
if self.codomain.discrete:
val = self.codomain.calc_weight(val, power=0.5)
val = self.codomain.weight(val, power=0.5, axes=axes)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
......
......@@ -359,8 +359,7 @@ class FFTW(Transform):
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
transform_shape=val.shape,
# TODO: check why inp.shape doesn't work
transform_shape=inp.shape,
**kwargs
)
......@@ -437,10 +436,6 @@ class FFTW(Transform):
val, axes, **kwargs
)
# If domain is purely real, the result of the FFT is hermitian
if self.domain.paradict['complexity'] == 0:
return_val.hermitian = True
return return_val
......@@ -636,10 +631,6 @@ class GFFT(Transform):
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if self.domain.paradict['complexity'] == 0:
new_val.hermitian = True
return_val = new_val
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
......
......@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation):
"""
if self._transform.codomain.harmonic:
# correct for forward fft
val = self._transform.domain.calc_weight(val, power=1)
val = self._transform.domain.weight(val, power=1, axes=axes)
# Perform the transformation
Tval = self._transform.transform(val, axes, **kwargs)
if not self._transform.codomain.harmonic:
# correct for inverse fft
Tval = self._transform.codomain.calc_weight(Tval, power=-1)
Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)
return Tval
# -*- coding: utf-8 -*-
from linear_operator import LinearOperator
from linear_operator_paradict import LinearOperatorParadict
......@@ -16,7 +16,7 @@ from nifty.config import about,\
dependency_injector as gdi
from lm_space_paradict import LMSpaceParadict
from nifty.nifty_power_indices import lm_power_indices
# from nifty.nifty_power_indices import lm_power_indices
from nifty.nifty_random import random
gl = gdi.get('libsharp_wrapper_gl')
......
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal, assert_raises
import itertools
import unittest
import d2o
import numpy as np
from nifty.rg.rg_space import gc as RG_GC
from nose_parameterized import parameterized
import unittest
import itertools
from numpy.testing import assert_raises
from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator
from nifty.transformations.rgrgtransformation import RGRGTransformation
from nifty.rg.rg_space import gc as RG_GC
import d2o
from nifty.operators.fft_operator.transformations.rgrgtransformation import RGRGTransformation
###############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment