Commit 76565174 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Axes keyword in FFT transforms

- GFFT now uses get_slice_list to get slices to iterate over all axes
parent 26862f93
Pipeline #3222 skipped
# -*- coding: utf-8 -*-
import numpy as np
from itertools import product
from nifty.d2o import distributed_data_object
from nifty.keepers import global_dependency_injector as gdi
import nifty.nifty_utilities as utilities
pyfftw = gdi.get('pyfftw')
gfft = gdi.get('gfft')
......@@ -445,29 +445,13 @@ class GFFT(FFT):
if domain.dtype == np.float64:
temp = temp.astype(np.complex128)
# Result is generated and stored in a local numpy array
res = np.empty_like(temp)
if axes:
# Result is generated and stored in a local numpy array
res = np.empty_like(temp)
# Generate iterables for the entire length of each axes, which
# we need to loop over.
axes_iterables = [
None
if x in axes else range(y) for x, y in enumerate(val.shape)
]
for index in product(*axes_iterables):
slice_list = [
index[axis]
if axis else slice(None, None) for x in len(temp.shape)
]
# Extract input from the generated slice
for slice_list in utilities.get_slice_list(temp.shape, axes):
inp = temp[slice_list]
if domain.dtype == np.float64:
inp = inp.astype(np.complex128)
inp = self.fft_machine.gfft(
inp,
in_ax=[],
......@@ -484,19 +468,8 @@ class GFFT(FFT):
)
res[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(res)
# If the values living in domain are purely real,
# the result of the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
val = new_val
else:
val = res
else:
temp = self.fft_machine.gfft(
res = self.fft_machine.gfft(
temp,
in_ax=[],
out_ax=[],
......@@ -511,15 +484,15 @@ class GFFT(FFT):
verbose=False
)
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(temp)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
val = new_val
else:
val = temp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(res)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
val = new_val
else:
val = res
return val
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment