Commit cd1951d6 authored by Jait Dixit's avatar Jait Dixit

WIP: TransformationOperator

- Create skeleton for TransformationOperator
- Minor bug fixes in Field, FFTTransform and RGRGTransformations and
  __init__.py for endomorphic_operator
parent d57b2f02
......@@ -52,11 +52,12 @@ from nifty_utilities import *
from field_types import FieldType,\
FieldArray
from operators import *
from spaces import *
from transformations import *
from operators import *
from demos import get_demo_dir
#import pyximport; pyximport.install(pyimport = True)
from transformations import *
......@@ -40,6 +40,7 @@ class Field(object):
start=start)
self.dtype = self._infer_dtype(dtype=dtype,
val=val,
domain=self.domain,
field_type=self.field_type)
......
......@@ -25,6 +25,8 @@ from linear_operator import LinearOperator
from endomorphic_operator import EndomorphicOperator
from transformation_operator import TransformationOperator
from nifty_operators import operator,\
diagonal_operator,\
power_operator,\
......
# -*- coding: utf-8 -*-
from endmorphic_operator import EndomorphicOperator
from endomorphic_operator import EndomorphicOperator
# -*- coding: utf-8 -*-
from linear_operator import LinearOperator
from linear_operator_paradict import LinearOperatorParadict
from transformation_operator import TransformationOperator
\ No newline at end of file
from nifty.config import about
import nifty.nifty_utilities as utilities
from nifty.operators.linear_operator import LinearOperator
from nifty.transformations import TransformationFactory
class TransformationOperator(LinearOperator):
def __init__(self, domain=(), field_type=(), target=(),
field_type_target=(), implemented=True):
super(TransformationOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
if self.domain == ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator needs a single space as '
'input domain.'
))
else:
if len(self.domain) > 1:
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator accepts only a single'
'space as input'
))
if self.field_type != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator field-type has to be an '
'empty tuple.'
))
# currently not sanitizing the target
self._target = self._parse_domain(target)
self._field_type_target = self._parse_field_type(field_type_target)
if self.field_type_target != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self._forward_transformation = TransformationFactory.create(
self.domain[0], self.target[0]
)
self._inverse_transformation = TransformationFactory.create(
self.target[0], self.domain[0]
)
@property
def target(self):
return self._target
@property
def field_type_target(self):
return self._field_type_target
def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._forward_transformation.transform(x.val, axes=spaces)
def _inverse_times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._inverse_transformation.transform(x.val, axes=spaces)
......@@ -16,7 +16,7 @@ from nifty.config import about,\
dependency_injector as gdi
from lm_space_paradict import LMSpaceParadict
from nifty.nifty_power_indices import lm_power_indices
# from nifty.nifty_power_indices import lm_power_indices
from nifty.nifty_random import random
gl = gdi.get('libsharp_wrapper_gl')
......
......@@ -359,8 +359,7 @@ class FFTW(Transform):
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
transform_shape=val.shape,
# TODO: check why inp.shape doesn't work
transform_shape=inp.shape,
**kwargs
)
......@@ -437,10 +436,6 @@ class FFTW(Transform):
val, axes, **kwargs
)
# If domain is purely real, the result of the FFT is hermitian
if self.domain.paradict['complexity'] == 0:
return_val.hermitian = True
return return_val
......@@ -636,10 +631,6 @@ class GFFT(Transform):
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if self.domain.paradict['complexity'] == 0:
new_val.hermitian = True
return_val = new_val
else:
return_val = return_val.astype(self.codomain.dtype, copy=False)
......
......@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation):
"""
if self._transform.codomain.harmonic:
# correct for forward fft
val = self._transform.domain.calc_weight(val, power=1)
val = self._transform.domain.weight(val, power=1)
# Perform the transformation
Tval = self._transform.transform(val, axes, **kwargs)
if not self._transform.codomain.harmonic:
# correct for inverse fft
Tval = self._transform.codomain.calc_weight(Tval, power=-1)
Tval = self._transform.codomain.weight(Tval, power=-1)
return Tval
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment