Commit 38c6be56 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Add axes functionality to GFFT

- GFFT now checks if the axes keyword is provided and only applies FFT to those axes.
parent dad0df03
Pipeline #2802 skipped
# -*- coding: utf-8 -*-
import numpy as np
from itertools import product
from nifty.d2o import distributed_data_object
from nifty.keepers import global_dependency_injector as gdi
......@@ -433,39 +434,71 @@ class GFFT(FFT):
result : np.ndarray
Fourier-transformed pendant of the input field.
"""
if codomain.harmonic:
ftmachine = "fft"
else:
ftmachine = "ifft"
# if the input is a distributed_data_object, extract the data
# GFFT is dumb. The entire input needs to be present on the node.
if isinstance(val, distributed_data_object):
d2oQ = True
temp = val.get_full_data()
else:
d2oQ = False
temp = val
# transform and return
if (domain.dtype == np.float64):
temp = self.fft_machine.gfft(
temp.astype(np.complex128),
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=bool(
codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False
)
if domain.dtype == np.float64:
temp = temp.astype(np.complex128)
if axes:
# Result is generated and stored in a local numpy array
res = np.empty_like(temp)
# Generate iterables for the entire length of each axes, which
# we need to loop over.
axes_iterables = [
None
if x in axes else range(y) for x, y in enumerate(val.shape)
]
for index in product(*axes_iterables):
slice_list = [
index[axis]
if axis else slice(None, None) for x in len(temp.shape)
]
# Extract input from the generated slice
inp = temp[slice_list]
if domain.dtype == np.float64:
inp = inp.astype(np.complex128)
inp = self.fft_machine.gfft(
inp,
in_ax=[],
out_ax=[],
ftmachine='fft' if codomain.harmonic else 'ifft',
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=bool(
codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False
)
res[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(res)
# If the values living in domain are purely real,
# the result of the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
val = new_val
else:
val = res
else:
temp = self.fft_machine.gfft(
temp,
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
ftmachine='fft' if codomain.harmonic else 'ifft',
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=bool(
......@@ -475,15 +508,16 @@ class GFFT(FFT):
alpha=-1,
verbose=False
)
if d2oQ:
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(temp)
# If the values living in domain are purely real, the
# result of the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
val = new_val
else:
val = temp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(temp)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
val = new_val
else:
val = temp
return val
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment