Commit 0d492a58 authored by Jait Dixit's avatar Jait Dixit

WIP: Move transforms to transformator

parent 1cad6829
......@@ -93,4 +93,4 @@ from demos import get_demo_dir
from pickling import _pickle_method, _unpickle_method
#import pyximport; pyximport.install(pyimport = True)
from transforms import tf as transformator
......@@ -34,7 +34,7 @@ setup(name="ift_nifty",
packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm",
"nifty.operators", "nifty.dummys", "nifty.field_types",
"nifty.config", "nifty.power"],
"nifty.config", "nifty.power", "nifty.transforms"],
package_dir={"nifty": ""},
from transform_factory import TransformFactory
tf = TransformFactory()
This diff is collapsed.
import numpy as np
from transform import FFT
from d2o import distributed_data_object
import nifty.nifty_utilities as utilities
class GFFT(FFT):
The gfft pendant of a fft object.
fft_module_name : String
Switch between the gfft module used: 'gfft' and 'gfft_dummy'
def __init__(self, domain, codomain, fft_module):
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
def transform(self, val, axes=None, **kwargs):
The gfft transform function.
val : numpy.ndarray or distributed_data_object
The value-array of the field which is supposed to
be transformed.
axes : None or tuple
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object):
temp_inp = val.get_full_data()
temp_inp = val
# Cast input datatype to codomain's dtype
temp_inp = temp_inp.astype(np.complex128, copy=False)
# Array for storing the result
return_val = None
for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
# don't copy the whole data array
if slice_list == [slice(None, None)]:
inp = temp_inp
# initialize the return_val object if needed
if return_val is None:
return_val = np.empty_like(temp_inp)
inp = temp_inp[slice_list]
inp = self.fft_machine.gfft(
ftmachine='fft' if self.codomain.harmonic else 'ifft',
bool, self.domain.paradict['zerocenter']
bool, self.codomain.paradict['zerocenter']
if slice_list == [slice(None, None)]:
return_val = inp
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=self.codomain.dtype)
new_val.set_full_data(return_val, copy=False)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if self.domain.paradict['complexity'] == 0:
new_val.hermitian = True
return_val = new_val
return_val = return_val.astype(self.codomain.dtype, copy=False)
return return_val
class FFT(object):
A generic fft object without any implementation.
def __init__(self):
def transform(self, val, domain, codomain, axes, **kwargs):
A generic ff-transform function.
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
raise NotImplementedError
from nifty.rg import RGSpace
from nifty.lm import GLSpace, HPSpace, LMSpace
from nifty.config import dependency_injector as gdi
from gfft import GFFT
from fftw import FFTW
class TransformFactory(object):
Transform factory which generates transform objects
def __init__(self):
# cache for storing the transform objects
self.cache = {}
def _get_transform_override(self, domain, codomain, module):
Please register or sign in to reply
if module == 'gfft':
return GFFT(domain, codomain, gdi.get('gfft'))
elif module == 'fftw':
return FFTW(domain, codomain)
elif module == 'gfft_dummmy':
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
def _get_transform(self, domain, codomain):
if isinstance(domain, RGSpace) and isinstance(codomain, RGSpace):
# fftw -> gfft -> gfft_dummy
if gdi.get('fftw') is None:
if gdi.get('gfft') is None:
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
return GFFT(domain, codomain, gdi.get('gfft'))
return FFTW(domain, codomain)
def create(self, domain, codomain, module=None):
key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
(179 * module.__hash__()))
if key not in self.cache:
if module is None:
self.cache[key] = self._get_transform(domain, codomain)
self.cache[key] = self._get_transform_override(domain,
return self.cache[key]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment