Commit f09cf8f2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

nicer structure

parent 71da3a26
Pipeline #21790 passed with stage
in 4 minutes and 17 seconds
......@@ -66,27 +66,26 @@ class RGRGTransformation(Transformation):
ldat = dobj.local_data(x.val)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
else: # we need redistribution and 1 or 2 FFT steps
if len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
rem_axes = tuple(i for i in axes if i != oldax)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=rem_axes)
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
rem_axes = tuple(i for i in axes if i != oldax)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=rem_axes)
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
Tval = Field(tdom, tmp)
fct = self.fct_p2h if p2h else self.fct_h2p
if fct != 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment