diff --git a/nifty/__init__.py b/nifty/__init__.py index 4166a49fe8c80e5b2507382b4d7163e281656c79..fbfb90ca74e31b79ec163017d448e633dbb34c94 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -52,11 +52,12 @@ from nifty_utilities import * from field_types import FieldType,\ FieldArray -from operators import * - from spaces import * +from transformations import * + +from operators import * + from demos import get_demo_dir #import pyximport; pyximport.install(pyimport = True) -from transformations import * diff --git a/nifty/field.py b/nifty/field.py index 749ecd46b1d8929b4385e20cf4f31ec350e441fe..164b0de1d28cd2cef44daf5c44dbfbb6512f4474 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -40,6 +40,7 @@ class Field(object): start=start) self.dtype = self._infer_dtype(dtype=dtype, + val=val, domain=self.domain, field_type=self.field_type) diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index 2512cd39b6e9b93e8eda48d823e4a74308556186..dfd5359cd43d5f4e77ebd8181084c5866b799037 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -25,6 +25,8 @@ from linear_operator import LinearOperator from endomorphic_operator import EndomorphicOperator +from transformation_operator import TransformationOperator + from nifty_operators import operator,\ diagonal_operator,\ power_operator,\ diff --git a/nifty/operators/endomorphic_operator/__init__.py b/nifty/operators/endomorphic_operator/__init__.py index 46344f49a85e3ee5a916c4501fe5f79164bafbcd..7c2fbe52a736361eeb63f0b851550b8a8c342d4a 100644 --- a/nifty/operators/endomorphic_operator/__init__.py +++ b/nifty/operators/endomorphic_operator/__init__.py @@ -1,3 +1,3 @@ # -*- coding: utf-8 -*- -from endmorphic_operator import EndomorphicOperator +from endomorphic_operator import EndomorphicOperator diff --git a/nifty/operators/linear_operator/__init__.py b/nifty/operators/linear_operator/__init__.py index c9bebe593f080465cce077c56bebf1c64d2cffb2..789d8c9924b5251603004adbbbf9b3936698b2ea 100644 --- a/nifty/operators/linear_operator/__init__.py +++ b/nifty/operators/linear_operator/__init__.py @@ -1,4 +1,3 @@ # -*- coding: utf-8 -*- from linear_operator import LinearOperator -from linear_operator_paradict import LinearOperatorParadict diff --git a/nifty/operators/transformation_operator/__init__.py b/nifty/operators/transformation_operator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..474fd83d00b6100b492e43f1f29e46d511000480 --- /dev/null +++ b/nifty/operators/transformation_operator/__init__.py @@ -0,0 +1 @@ +from transformation_operator import TransformationOperator \ No newline at end of file diff --git a/nifty/operators/transformation_operator/transformation_operator.py b/nifty/operators/transformation_operator/transformation_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..168d36ff28ec7043c8c758bb9e4770bf0d694ac3 --- /dev/null +++ b/nifty/operators/transformation_operator/transformation_operator.py @@ -0,0 +1,66 @@ +from nifty.config import about +import nifty.nifty_utilities as utilities +from nifty.operators.linear_operator import LinearOperator +from nifty.transformations import TransformationFactory + + +class TransformationOperator(LinearOperator): + def __init__(self, domain=(), field_type=(), target=(), + field_type_target=(), implemented=True): + super(TransformationOperator, self).__init__(domain=domain, + field_type=field_type, + implemented=implemented) + + if self.domain == (): + raise TypeError(about._errors.cstring( + 'ERROR: TransformationOperator needs a single space as ' + 'input domain.' + )) + else: + if len(self.domain) > 1: + raise TypeError(about._errors.cstring( + 'ERROR: TransformationOperator accepts only a single' + 'space as input' + )) + + if self.field_type != (): + raise TypeError(about._errors.cstring( + 'ERROR: TransformationOperator field-type has to be an ' + 'empty tuple.' + )) + + # currently not sanitizing the target + self._target = self._parse_domain(target) + self._field_type_target = self._parse_field_type(field_type_target) + + if self.field_type_target != (): + raise TypeError(about._errors.cstring( + 'ERROR: TransformationOperator target field-type has to be an ' + 'empty tuple.' + )) + + self._forward_transformation = TransformationFactory.create( + self.domain[0], self.target[0] + ) + + self._inverse_transformation = TransformationFactory.create( + self.target[0], self.domain[0] + ) + + @property + def target(self): + return self._target + + @property + def field_type_target(self): + return self._field_type_target + + def _times(self, x, spaces, types): + spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) + + return self._forward_transformation.transform(x.val, axes=spaces) + + def _inverse_times(self, x, spaces, types): + spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) + + return self._inverse_transformation.transform(x.val, axes=spaces) diff --git a/nifty/spaces/lm_space/lm_space.py b/nifty/spaces/lm_space/lm_space.py index 245dc8125f5a683c8bb87b2018b390f577d4cec2..89a6811c540e28a01f9e6fd91f1e73b8ca08be75 100644 --- a/nifty/spaces/lm_space/lm_space.py +++ b/nifty/spaces/lm_space/lm_space.py @@ -16,7 +16,7 @@ from nifty.config import about,\ dependency_injector as gdi from lm_space_paradict import LMSpaceParadict -from nifty.nifty_power_indices import lm_power_indices +# from nifty.nifty_power_indices import lm_power_indices from nifty.nifty_random import random gl = gdi.get('libsharp_wrapper_gl') diff --git a/nifty/transformations/rg_transforms.py b/nifty/transformations/rg_transforms.py index 955a1989a368cd6b6f6a70ede1ebed1a4e36758d..202a7e99e64138141c41ea74d74d1d65f9709f4b 100644 --- a/nifty/transformations/rg_transforms.py +++ b/nifty/transformations/rg_transforms.py @@ -359,8 +359,7 @@ class FFTW(Transform): local_shape=val.local_shape, local_offset_Q=local_offset_Q, is_local=False, - transform_shape=val.shape, - # TODO: check why inp.shape doesn't work + transform_shape=inp.shape, **kwargs ) @@ -437,10 +436,6 @@ class FFTW(Transform): val, axes, **kwargs ) - # If domain is purely real, the result of the FFT is hermitian - if self.domain.paradict['complexity'] == 0: - return_val.hermitian = True - return return_val @@ -636,10 +631,6 @@ class GFFT(Transform): if isinstance(val, distributed_data_object): new_val = val.copy_empty(dtype=self.codomain.dtype) new_val.set_full_data(return_val, copy=False) - # If the values living in domain are purely real, the result of - # the fft is hermitian - if self.domain.paradict['complexity'] == 0: - new_val.hermitian = True return_val = new_val else: return_val = return_val.astype(self.codomain.dtype, copy=False) diff --git a/nifty/transformations/rgrgtransformation.py b/nifty/transformations/rgrgtransformation.py index b711fdc23b40dc1aa955114fbf23b613525f3ea4..b643eda0fb9723a04dda7dd94fde3f9cf91294e1 100644 --- a/nifty/transformations/rgrgtransformation.py +++ b/nifty/transformations/rgrgtransformation.py @@ -126,13 +126,13 @@ class RGRGTransformation(Transformation): """ if self._transform.codomain.harmonic: # correct for forward fft - val = self._transform.domain.calc_weight(val, power=1) + val = self._transform.domain.weight(val, power=1) # Perform the transformation Tval = self._transform.transform(val, axes, **kwargs) if not self._transform.codomain.harmonic: # correct for inverse fft - Tval = self._transform.codomain.calc_weight(Tval, power=-1) + Tval = self._transform.codomain.weight(Tval, power=-1) return Tval