Commit 27bcdb96 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: FFTW transform method

- Refactor transform() according to discussions
- Add checks for validity of axes
parent e66fac9a
Pipeline #4668 skipped
# -*- coding: utf-8 -*-
import numpy as np
from d2o import distributed_data_object
from d2o import distributed_data_object, STRATEGIES
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
......@@ -300,63 +300,86 @@ class FFTW(FFT):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
current_info = self._get_transform_info(domain, codomain,
is_local=True,
**kwargs)
# Copy data for local manipulation
# local transform doesn't apply transforms inplace
return_val = self._local_transform(val, current_info, axes,
domain, codomain)
else:
if val.distribution_strategy is 'fftw':
# Create a local transform if the axes is anything other than
# None or (0,) explicitly.
current_info = self._get_transform_info(
domain, codomain,
is_local=(False if axes is None or axes == (0,) else True),
**kwargs
)
if val.distribution_strategy in STRATEGIES['slicing']:
if axes is None or set(axes) == set(range(len(val.shape))):
current_info = self._get_transform_info(domain, codomain,
**kwargs)
# Create the object which will be returned
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Compute transform for the local data
result = self._local_transform(
val.get_local_data(copy=False),
current_info, axes,
domain, codomain
)
# Insert result into the return_val
return_val.set_local_data(data=result, copy=False)
elif 0 in axes:
current_info = self._get_transform_info(domain, codomain,
**kwargs)
# Repack val to fftw
new_val = val.copy(distribution_strategy='fftw',
copy=False)
return_val = new_val.copy_empty(global_shape=new_val.shape,
dtype=codomain.dtype)
# Extract local data
local_val = val.get_local_data(copy=False)
# Intermediate storage for individual slices, if transforming
# only along (0,)
temp_val = None
# Intermediate storage for individual slices
temp_val = np.empty_like(local_val)
for slice_list in utilities.get_slice_list(local_val.shape,
axes):
if slice_list == [slice(None, None)]:
inp = local_val
else:
if temp_val is None:
temp_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
# Transfom the data
inp = local_val[slice_list]
result = self._local_transform(inp, current_info,
axes, domain, codomain)
temp_val[slice_list] = result
if slice_list == [slice(None, None)]:
return_val.set_local_data(data=result, copy=False)
return_val.set_local_data(data=temp_val, copy=False)
# Repack to original distribution strategy
return_val = return_val.copy(
distribution_strategy=val.distribution_strategy,
copy=False
)
else:
temp_val[slice_list] = result
current_info = self._get_transform_info(domain, codomain,
is_local=True,
**kwargs)
if temp_val is not None:
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Compute transform for the local data
result = self._local_transform(
val.get_local_data(copy=False),
current_info, axes,
domain, codomain
)
# Insert result inplace in return_val
return_val.set_local_data(data=result, copy=False)
# If the values living in domain are purely real, the
# resulting FFT is hermitian
# TODO: in case of multiple spaces using one d2o array, there
# isn't anymore 'the' hermitianity
# If domain is purely real, the result of the FFT is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
else:
new_val = val.copy(distribution_strategy='fftw',
copy=False)
......@@ -370,6 +393,11 @@ class FFTW(FFT):
distribution_strategy=val.distribution_strategy,
copy=False
)
# If domain is purely real, the result of the FFT is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
return return_val
......@@ -509,6 +537,11 @@ class GFFT(FFT):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment