Commit 09f3edac authored by Martin Reinecke
Browse files

turbocharged FFT step 1/n

parent f09cf8f2
Pipeline #21825 passed with stage
in 4 minutes and 16 seconds
......@@ -419,7 +419,7 @@ def redistribute(arr, dist=None, nodist=None):
s_msg = [sbuf, (ssz, sdisp), MPI.BYTE]
r_msg = [rbuf, (rsz, rdisp), MPI.BYTE]
_comm.Alltoallv(s_msg, r_msg)
del sbuf # free memory
if arr._distaxis == 0:
rbuf = rbuf.reshape(local_shape(arr.shape, dist))
arrnew = from_local_data(arr.shape, rbuf, distaxis=dist)
......@@ -436,5 +436,40 @@ def redistribute(arr, dist=None, nodist=None):
return arrnew
def transpose(arr):
    """Return the transpose of a 2D array that is distributed along axis 0.

    The local data is cut into one column block per task, the blocks are
    exchanged with a single ``Alltoallv`` call, and every received block is
    transposed into its place in the local part of the result.

    Parameters
    ----------
    arr : distributed array object
        Must be two-dimensional with ``arr._distaxis == 0``.

    Returns
    -------
    distributed array object
        New array of shape ``(arr.shape[1], arr.shape[0])``, created with
        ``distaxis=0`` and the same dtype as `arr`.

    Raises
    ------
    ValueError
        If `arr` is not two-dimensional or is not distributed along axis 0.
    """
    if len(arr.shape) != 2 or arr._distaxis != 0:
        raise ValueError("bad input")
    itemsize = arr._data.itemsize
    # Local element count divided by the global number of columns.
    ssz0 = arr._data.size//arr.shape[1]
    # The builtin `int` is used as dtype: the alias `np.int` meant exactly
    # that and has been removed from NumPy.
    ssz = np.empty(ntask, dtype=int)
    rszall = arr.size//arr.shape[1]*_shareSize(arr.shape[1], ntask, rank)
    rbuf = np.empty(rszall, dtype=arr.dtype)
    rsz0 = rszall//arr.shape[0]
    rsz = np.empty(ntask, dtype=int)
    sbuf = np.empty(arr._data.size, dtype=arr.dtype)
    ofs = 0
    for i in range(ntask):
        # Pack the column block destined for task i contiguously.
        lo, hi = _shareRange(arr.shape[1], ntask, i)
        ssz[i] = ssz0*(hi-lo)
        sbuf[ofs:ofs+ssz[i]] = arr._data[:, lo:hi].flat
        ofs += ssz[i]
        rsz[i] = rsz0*_shareSize(arr.shape[0], ntask, i)
    # The messages are described in bytes (MPI.BYTE), not in elements.
    ssz *= itemsize
    rsz *= itemsize
    sdisp = np.append(0, np.cumsum(ssz[:-1]))
    rdisp = np.append(0, np.cumsum(rsz[:-1]))
    s_msg = [sbuf, (ssz, sdisp), MPI.BYTE]
    r_msg = [rbuf, (rsz, rdisp), MPI.BYTE]
    _comm.Alltoallv(s_msg, r_msg)
    del sbuf  # free memory
    arrnew = empty((arr.shape[1], arr.shape[0]), dtype=arr.dtype, distaxis=0)
    ofs = 0
    for i in range(ntask):
        lo, hi = _shareRange(arr.shape[0], ntask, i)
        sz = rsz[i]//itemsize
        # An empty share has nothing to copy, and reshape(0, -1) would raise
        # because the remaining dimension cannot be inferred.
        if hi > lo:
            arrnew._data[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, -1).T
        ofs += sz
    return arrnew
def default_distaxis():
    """Return the index of the axis used for distribution by default."""
    axis = 0
    return axis
......@@ -75,17 +75,54 @@ class RGRGTransformation(Transformation):
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
rem_axes = tuple(i for i in axes if i != oldax)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=rem_axes)
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
if True:
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
# new, experimental code
if True:
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],-1))
shp2d = (x.val.shape[0], np.prod(x.val.shape[1:]))
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
else:
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
tmp = dobj.redistribute(tmp, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else:
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
rem_axes = tuple(i for i in axes if i != oldax)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=rem_axes)
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
Tval = Field(tdom, tmp)
fct = self.fct_p2h if p2h else self.fct_h2p
if fct != 1:
......
......@@ -158,13 +158,56 @@ def hartley(a, axes=None):
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if issubclass(a.dtype.type, np.complexfloating):
raise TypeError("Hartley tansform requires real-valued arrays.")
raise TypeError("Hartley transform requires real-valued arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
tmp = rfftn(a, axes=axes)
return _fill_array(tmp, np.empty_like(a), axes)
def _fill_upper_half_complex(tmp, res, axes):
lastaxis = axes[-1]
nlast = res.shape[lastaxis]
ntmplast = tmp.shape[lastaxis]
nrem = nlast - ntmplast
slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
for i in axes[:-1]:
slice1[i] = slice(1, None)
slice2[i] = slice(None, 0, -1)
#np.conjugate(tmp[slice2], out=res[slice1])
res[slice1] = np.conjugate(tmp[slice2])
for i, ax in enumerate(axes[:-1]):
dim1 = [slice(None)]*ax + [slice(0, 1)]
axes2 = axes[:i] + axes[i+1:]
_fill_upper_half_complex(tmp[dim1], res[dim1], axes2)
def _fill_complex_array(tmp, res, axes):
    """Copy `tmp` into `res` and fill the remaining part of `res`.

    `tmp` is written to the leading entries of `res` along the last axis in
    `axes`; the rest of `res` is then filled by
    ``_fill_upper_half_complex``.

    Parameters
    ----------
    tmp : numpy.ndarray
        Array that is shorter than `res` along ``axes[-1]``.
    res : numpy.ndarray
        Output array, modified in place.
    axes : sequence of int or None
        Transformed axes; ``None`` selects all axes of `tmp`. Negative values
        count from the end.

    Returns
    -------
    numpy.ndarray
        `res`, after it has been filled.
    """
    ndim = tmp.ndim
    if axes is None:
        axes = tuple(range(ndim))
    else:
        # Negative axes cannot be used as list positions below.
        axes = tuple(ax + ndim if ax < 0 else ax for ax in axes)
    lastaxis = axes[-1]
    ntmplast = tmp.shape[lastaxis]
    lower = [slice(None)]*ndim
    lower[lastaxis] = slice(0, ntmplast)
    # NumPy requires a tuple (not a list) for multidimensional indexing.
    res[tuple(lower)] = tmp
    _fill_upper_half_complex(tmp, res, axes)
    return res
def my_fftn_r2c(a, axes=None):
    """Do a real-to-complex forward FFT and return the _full_ output array.

    Parameters
    ----------
    a : numpy.ndarray
        Real-valued input array.
    axes : sequence of int, optional
        Axes handed to ``rfftn``; ``None`` is passed through unchanged.

    Returns
    -------
    numpy.ndarray
        Array with the shape of `a` and the dtype of the ``rfftn`` result,
        filled by ``_fill_complex_array``.

    Raises
    ------
    ValueError
        If an entry of `axes` is not smaller than the number of dimensions.
    TypeError
        If `a` has a complex dtype.
    """
    # Check if the axes provided are valid given the shape
    ndim = len(a.shape)
    if axes is not None and any(axis >= ndim for axis in axes):
        raise ValueError("Provided axes do not match array shape")

    if issubclass(a.dtype.type, np.complexfloating):
        raise TypeError("Transform requires real-valued input arrays.")

    from pyfftw.interfaces.numpy_fft import rfftn
    half = rfftn(a, axes=axes)
    full = np.empty_like(a, dtype=half.dtype)
    return _fill_complex_array(half, full, axes)
def general_axpy(a, x, y, out):
if x.domain != y.domain or x.domain != out.domain:
raise ValueError("Incompatible domains")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment