Commit 71da3a26 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve FFT handling; still more work to do

parent ee4725a6
Pipeline #21787 passed with stage
in 4 minutes and 15 seconds
......@@ -101,7 +101,7 @@ In oder to run the tests one needs two additional packages:
Afterwards the tests (including a coverage report) are run using the following
command in the repository root:
nosetests --exe --cover-html
nosetests -x --with-coverage --cover-html --cover-package=nifty2go
### First Steps
......
......@@ -62,15 +62,23 @@ class RGRGTransformation(Transformation):
p2h = x.domain == self.pdom
tdom = self.hdom if p2h else self.pdom
oldax = dobj.distaxis(x.val)
if dobj.distaxis(x.val) in axes:
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
if len(axes) == 1: # only one transform needed
ldat = utilities.hartley(ldat, axes=(oldax,))
if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
else: # we need redistribution and 1 or 2 FFT steps
if len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate transforms needed, "real" FFT required
else: # two separate, full FFTs needed
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
......@@ -79,10 +87,6 @@ class RGRGTransformation(Transformation):
ldat = fftn(ldat, axes=rem_axes)
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
else:
ldat = dobj.local_data(x.val)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
Tval = Field(tdom, tmp)
fct = self.fct_p2h if p2h else self.fct_h2p
if fct != 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment