Commit 27bcdb96 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: FFTW transform method

- Refactor transform() according to discussions
- Add checks for validity of axes
parent e66fac9a
Pipeline #4668 skipped
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np import numpy as np
from d2o import distributed_data_object from d2o import distributed_data_object, STRATEGIES
from nifty.config import dependency_injector as gdi from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
...@@ -300,63 +300,86 @@ class FFTW(FFT): ...@@ -300,63 +300,86 @@ class FFTW(FFT):
result : np.ndarray or distributed_data_object result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
# If the input is a numpy array we transform it locally # If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object): if not isinstance(val, distributed_data_object):
current_info = self._get_transform_info(domain, codomain, current_info = self._get_transform_info(domain, codomain,
is_local=True, is_local=True,
**kwargs) **kwargs)
# Copy data for local manipulation # local transform doesn't apply transforms inplace
return_val = self._local_transform(val, current_info, axes, return_val = self._local_transform(val, current_info, axes,
domain, codomain) domain, codomain)
else: else:
if val.distribution_strategy is 'fftw': if val.distribution_strategy in STRATEGIES['slicing']:
# Create a local transform if the axes is anything other than if axes is None or set(axes) == set(range(len(val.shape))):
# None or (0,) explicitly. current_info = self._get_transform_info(domain, codomain,
current_info = self._get_transform_info( **kwargs)
domain, codomain,
is_local=(False if axes is None or axes == (0,) else True), return_val = val.copy_empty(global_shape=val.shape,
**kwargs dtype=codomain.dtype)
)
# Compute transform for the local data
# Create the object which will be returned result = self._local_transform(
return_val = val.copy_empty(global_shape=val.shape, val.get_local_data(copy=False),
dtype=codomain.dtype) current_info, axes,
domain, codomain
# Extract local data )
local_val = val.get_local_data(copy=False)
# Insert result into the return_val
# Intermediate storage for individual slices, if transforming return_val.set_local_data(data=result, copy=False)
# only along (0,) elif 0 in axes:
temp_val = None current_info = self._get_transform_info(domain, codomain,
**kwargs)
for slice_list in utilities.get_slice_list(local_val.shape, # Repack val to fftw
axes): new_val = val.copy(distribution_strategy='fftw',
if slice_list == [slice(None, None)]: copy=False)
inp = local_val return_val = new_val.copy_empty(global_shape=new_val.shape,
else: dtype=codomain.dtype)
if temp_val is None: # Extract local data
temp_val = np.empty_like(temp_val) local_val = val.get_local_data(copy=False)
inp = temp_val[slice_list]
# Intermediate storage for individual slices
# Transfom the data temp_val = np.empty_like(local_val)
result = self._local_transform(inp, current_info,
axes, domain, codomain) for slice_list in utilities.get_slice_list(local_val.shape,
axes):
if slice_list == [slice(None, None)]: inp = local_val[slice_list]
return_val.set_local_data(data=result, copy=False) result = self._local_transform(inp, current_info,
else: axes, domain, codomain)
temp_val[slice_list] = result temp_val[slice_list] = result
if temp_val is not None: return_val.set_local_data(data=temp_val, copy=False)
# Repack to original distribution strategy
return_val = return_val.copy(
distribution_strategy=val.distribution_strategy,
copy=False
)
else:
current_info = self._get_transform_info(domain, codomain,
is_local=True,
**kwargs)
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Compute transform for the local data
result = self._local_transform(
val.get_local_data(copy=False),
current_info, axes,
domain, codomain
)
# Insert result inplace in return_val
return_val.set_local_data(data=result, copy=False) return_val.set_local_data(data=result, copy=False)
# If the values living in domain are purely real, the # If domain is purely real, the result of the FFT is hermitian
# resulting FFT is hermitian
# TODO: in case of multiple spaces using one d2o array, there
# isn't anymore 'the' hermitianity
if domain.paradict['complexity'] == 0: if domain.paradict['complexity'] == 0:
return_val.hermitian = True return_val.hermitian = True
else: else:
new_val = val.copy(distribution_strategy='fftw', new_val = val.copy(distribution_strategy='fftw',
copy=False) copy=False)
...@@ -370,6 +393,11 @@ class FFTW(FFT): ...@@ -370,6 +393,11 @@ class FFTW(FFT):
distribution_strategy=val.distribution_strategy, distribution_strategy=val.distribution_strategy,
copy=False copy=False
) )
# If domain is purely real, the result of the FFT is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
return return_val return return_val
...@@ -509,6 +537,11 @@ class GFFT(FFT): ...@@ -509,6 +537,11 @@ class GFFT(FFT):
result : np.ndarray or distributed_data_object result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
# GFFT doesn't accept d2o objects as input. Consolidate data from # GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding. # all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment