Commit fd0d36e8 authored by Martin Reinecke's avatar Martin Reinecke

enable experimental pypocketfft support

parent 35dc4f9b
Pipeline #47214 passed with stages
in 16 minutes and 45 seconds
...@@ -15,6 +15,7 @@ RUN apt-get update && apt-get install -y \ ...@@ -15,6 +15,7 @@ RUN apt-get update && apt-get install -y \
&& pip3 install pyfftw \ && pip3 install pyfftw \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \ && pip3 install git+https://gitlab.mpcdf.mpg.de/ift/pyHealpix.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \ && pip3 install git+https://gitlab.mpcdf.mpg.de/ift/nifty_gridder.git \
&& pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft.git \
&& pip3 install jupyter \ && pip3 install jupyter \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from .utilities import iscomplextype from .utilities import iscomplextype
import numpy as np import numpy as np
import pypocketfft
_use_fftw = False _use_fftw = False
...@@ -51,13 +52,25 @@ def _init_pyfftw(): ...@@ -51,13 +52,25 @@ def _init_pyfftw():
_fftw_prepped = True _fftw_prepped = True
# FIXME this should not be necessary ... no one should call a complex FFT
# with a float array.
def _make_complex(a):
if a.dtype in (np.complex64, np.complex128):
return a
if a.dtype == np.float64:
return a.astype(np.complex128)
if a.dtype == np.float32:
return a.astype(np.complex64)
raise NotImplementedError
def fftn(a, axes=None): def fftn(a, axes=None):
if _use_fftw: if _use_fftw:
from pyfftw.interfaces.numpy_fft import fftn from pyfftw.interfaces.numpy_fft import fftn
_init_pyfftw() _init_pyfftw()
return fftn(a, axes=axes, **_fft_extra_args) return fftn(a, axes=axes, **_fft_extra_args)
else: else:
return np.fft.fftn(a, axes=axes) return pypocketfft.fftn(_make_complex(a), axes=axes)
def rfftn(a, axes=None): def rfftn(a, axes=None):
...@@ -66,7 +79,10 @@ def rfftn(a, axes=None): ...@@ -66,7 +79,10 @@ def rfftn(a, axes=None):
_init_pyfftw() _init_pyfftw()
return rfftn(a, axes=axes, **_fft_extra_args) return rfftn(a, axes=axes, **_fft_extra_args)
else: else:
return np.fft.rfftn(a, axes=axes) if axes is None:
return pypocketfft.rfftn(a)
else:
return pypocketfft.rfftn(a, axes=axes)
def ifftn(a, axes=None): def ifftn(a, axes=None):
...@@ -75,7 +91,12 @@ def ifftn(a, axes=None): ...@@ -75,7 +91,12 @@ def ifftn(a, axes=None):
_init_pyfftw() _init_pyfftw()
return ifftn(a, axes=axes, **_fft_extra_args) return ifftn(a, axes=axes, **_fft_extra_args)
else: else:
return np.fft.ifftn(a, axes=axes) # FIXME this is a temporary fix and can be done more elegantly
if axes is None:
fct = 1./a.size
else:
fct = 1./np.prod(np.take(a.shape, axes))
return pypocketfft.ifftn(_make_complex(a), axes=axes, fct=fct)
def hartley(a, axes=None): def hartley(a, axes=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment