Skip to content
Snippets Groups Projects
Commit 7342ee68 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

support nthread selection

parent 36b04d9c
No related branches found
No related tags found
No related merge requests found
...@@ -5,6 +5,7 @@ import pypocketfft ...@@ -5,6 +5,7 @@ import pypocketfft
from time import time from time import time
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
nthreads=0
def _l2error(a,b): def _l2error(a,b):
return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
...@@ -22,7 +23,7 @@ def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""): ...@@ -22,7 +23,7 @@ def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""):
tmin_pp=1e38 tmin_pp=1e38
for i in range(nrepeat): for i in range(nrepeat):
t0=time() t0=time()
b=pypocketfft.fftn(a) b=pypocketfft.fftn(a,nthreads=nthreads)
t1=time() t1=time()
tmin_pp = min(tmin_pp,t1-t0) tmin_pp = min(tmin_pp,t1-t0)
a2=pypocketfft.ifftn(b,fct=1./a.size) a2=pypocketfft.ifftn(b,fct=1./a.size)
......
...@@ -4,7 +4,7 @@ import pypocketfft ...@@ -4,7 +4,7 @@ import pypocketfft
def _l2error(a,b): def _l2error(a,b):
return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
nthreads=0
cmaxerr=0. cmaxerr=0.
fmaxerr=0. fmaxerr=0.
cmaxerrf=0. cmaxerrf=0.
...@@ -23,32 +23,32 @@ def test(): ...@@ -23,32 +23,32 @@ def test():
lastsize = shape[axes[-1]] lastsize = shape[axes[-1]]
fct = 1./np.prod(np.take(shape, axes)) fct = 1./np.prod(np.take(shape, axes))
a=np.random.rand(*shape)-0.5 + 1j*np.random.rand(*shape)-0.5j a=np.random.rand(*shape)-0.5 + 1j*np.random.rand(*shape)-0.5j
b=pypocketfft.ifftn(pypocketfft.fftn(a,axes=axes),axes=axes,fct=fct) b=pypocketfft.ifftn(pypocketfft.fftn(a,axes=axes,nthreads=nthreads),axes=axes,fct=fct,nthreads=nthreads)
err = _l2error(a,b) err = _l2error(a,b)
if err > cmaxerr: if err > cmaxerr:
cmaxerr = err cmaxerr = err
print("cmaxerr:", cmaxerr, shape, axes) print("cmaxerr:", cmaxerr, shape, axes)
b=pypocketfft.irfftn(pypocketfft.rfftn(a.real,axes=axes),axes=axes,fct=fct,lastsize=lastsize) b=pypocketfft.irfftn(pypocketfft.rfftn(a.real,axes=axes,nthreads=nthreads),axes=axes,fct=fct,lastsize=lastsize,nthreads=nthreads)
err = _l2error(a.real,b) err = _l2error(a.real,b)
if err > fmaxerr: if err > fmaxerr:
fmaxerr = err fmaxerr = err
print("fmaxerr:", fmaxerr, shape, axes) print("fmaxerr:", fmaxerr, shape, axes)
b=pypocketfft.ifftn(pypocketfft.fftn(a.astype(np.complex64),axes=axes),axes=axes,fct=fct) b=pypocketfft.ifftn(pypocketfft.fftn(a.astype(np.complex64),axes=axes,nthreads=nthreads),axes=axes,fct=fct,nthreads=nthreads)
err = _l2error(a.astype(np.complex64),b) err = _l2error(a.astype(np.complex64),b)
if err > cmaxerrf: if err > cmaxerrf:
cmaxerrf = err cmaxerrf = err
print("cmaxerrf:", cmaxerrf, shape, axes) print("cmaxerrf:", cmaxerrf, shape, axes)
b=pypocketfft.irfftn(pypocketfft.rfftn(a.real.astype(np.float32),axes=axes),axes=axes,fct=fct,lastsize=lastsize) b=pypocketfft.irfftn(pypocketfft.rfftn(a.real.astype(np.float32),axes=axes,nthreads=nthreads),axes=axes,fct=fct,lastsize=lastsize,nthreads=nthreads)
err = _l2error(a.real.astype(np.float32),b) err = _l2error(a.real.astype(np.float32),b)
if err > fmaxerrf: if err > fmaxerrf:
fmaxerrf = err fmaxerrf = err
print("fmaxerrf:", fmaxerrf, shape, axes) print("fmaxerrf:", fmaxerrf, shape, axes)
b=pypocketfft.hartley(pypocketfft.hartley(a.real,axes=axes),axes=axes,fct=fct) b=pypocketfft.hartley(pypocketfft.hartley(a.real,axes=axes,nthreads=nthreads),axes=axes,fct=fct,nthreads=nthreads)
err = _l2error(a.real,b) err = _l2error(a.real,b)
if err > hmaxerr: if err > hmaxerr:
hmaxerr = err hmaxerr = err
print("hmaxerr:", hmaxerr, shape, axes) print("hmaxerr:", hmaxerr, shape, axes)
b=pypocketfft.hartley(pypocketfft.hartley(a.real.astype(np.float32),axes=axes),axes=axes,fct=fct) b=pypocketfft.hartley(pypocketfft.hartley(a.real.astype(np.float32),axes=axes,nthreads=nthreads),axes=axes,fct=fct,nthreads=nthreads)
err = _l2error(a.real.astype(np.float32),b) err = _l2error(a.real.astype(np.float32),b)
if err > hmaxerrf: if err > hmaxerrf:
hmaxerrf = err hmaxerrf = err
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment