Commit d7258a63 authored by Martin Reinecke's avatar Martin Reinecke

adjust to new pocketfft interface

parent 00b8f07a
Pipeline #51090 failed with stages
in 5 minutes and 23 seconds
......@@ -15,8 +15,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .utilities import iscomplextype
import numpy as np
import pypocketfft
_nthreads = 1
......@@ -31,81 +29,14 @@ def set_nthreads(nthr):
_nthreads = nthr
# FIXME this should not be necessary ... no one should call a complex FFT
# with a float array.
def _make_complex(a):
if a.dtype in (np.complex64, np.complex128):
return a
if a.dtype == np.float64:
return a.astype(np.complex128)
if a.dtype == np.float32:
return a.astype(np.complex64)
raise NotImplementedError
def fftn(a, axes=None):
return pypocketfft.fftn(_make_complex(a), axes=axes, nthreads=_nthreads)
def rfftn(a, axes=None):
return pypocketfft.rfftn(a, axes=axes, nthreads=_nthreads)
return pypocketfft.c2c(a, axes=axes, nthreads=_nthreads)
def ifftn(a, axes=None):
# FIXME this is a temporary fix and can be done more elegantly
if axes is None:
fct = 1./a.size
fct = 1./, axes))
return pypocketfft.ifftn(_make_complex(a), axes=axes, fct=fct,
return pypocketfft.c2c(a, axes=axes, inorm=2, forward=False,
def hartley(a, axes=None):
return pypocketfft.hartley2(a, axes=axes, nthreads=_nthreads)
# Do a real-to-complex forward FFT and return the _full_ output array
def my_fftn_r2c(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if iscomplextype(a.dtype):
raise TypeError("Transform requires real-valued input arrays.")
tmp = rfftn(a, axes=axes)
def _fill_complex_array(tmp, res, axes):
if axes is None:
axes = tuple(range(tmp.ndim))
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)]
res[tuple(slice1)] = tmp
def _fill_upper_half_complex(tmp, res, axes):
lastaxis = axes[-1]
nlast = res.shape[lastaxis]
ntmplast = tmp.shape[lastaxis]
nrem = nlast - ntmplast
slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
# np.conjugate(tmp[slice2], out=res[slice1])
res[tuple(slice1)] = np.conjugate(tmp[tuple(slice2)])
for i, ax in enumerate(axes[:-1]):
dim1 = tuple([slice(None)]*ax + [slice(0, 1)])
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half_complex(tmp[dim1], res[dim1], axes2)
_fill_upper_half_complex(tmp, res, axes)
return res
return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
def my_fftn(a, axes=None):
return fftn(a, axes=axes)
return pypocketfft.genuine_hartley(a, axes=axes, nthreads=_nthreads)
......@@ -202,7 +202,7 @@ class HartleyOperator(LinearOperator):
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = fft.my_fftn_r2c(ldat, axes=rem_axes)
ldat = fft.fftn(ldat, axes=rem_axes)
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],
......@@ -211,7 +211,7 @@ class HartleyOperator(LinearOperator):
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = fft.my_fftn(ldat2, axes=(1,))
ldat2 = fft.fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment