# stress.py -- randomized round-trip accuracy stress test for pypocketfft.
import numpy as np
import pypocketfft


def _l2error(a, b):
    return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))


def fftn(a, axes=None, inorm=0, out=None, nthreads=1):
    """Forward complex-to-complex FFT of ``a`` along ``axes`` (pypocketfft.c2c).

    ``inorm`` selects the normalization and ``out`` an optional output
    array; both are forwarded unchanged.  NOTE(review): presumably
    ``inorm=0`` means unnormalized -- confirm against the pypocketfft docs.
    """
    settings = {"axes": axes, "forward": True, "inorm": inorm,
                "out": out, "nthreads": nthreads}
    return pypocketfft.c2c(a, **settings)


def ifftn(a, axes=None, inorm=0, out=None, nthreads=1):
    """Backward complex-to-complex FFT of ``a`` along ``axes`` (pypocketfft.c2c).

    ``inorm`` selects the normalization and ``out`` an optional output
    array; both are forwarded unchanged.  The round trips in this file pass
    ``inorm=2`` here -- NOTE(review): presumably that scales by 1/N; confirm.
    """
    settings = {"axes": axes, "forward": False, "inorm": inorm,
                "out": out, "nthreads": nthreads}
    return pypocketfft.c2c(a, **settings)


def rfftn(a, axes=None, inorm=0, nthreads=1):
    """Forward real-to-complex FFT of ``a`` along ``axes`` (pypocketfft.r2c).

    ``inorm`` selects the normalization and is forwarded unchanged.
    """
    settings = {"axes": axes, "forward": True, "inorm": inorm,
                "nthreads": nthreads}
    return pypocketfft.r2c(a, **settings)


def irfftn(a, axes=None, lastsize=0, inorm=0, nthreads=1):
    """Backward complex-to-real FFT of ``a`` along ``axes`` (pypocketfft.c2r).

    ``lastsize`` is the length of the last transformed axis in the real
    output; this file passes the original input's length there.
    NOTE(review): the meaning of the default ``lastsize=0`` is decided by
    pypocketfft -- confirm before relying on it.
    """
    settings = {"axes": axes, "lastsize": lastsize, "forward": False,
                "inorm": inorm, "nthreads": nthreads}
    return pypocketfft.c2r(a, **settings)


# Thread count passed to every transform call made by test().
# NOTE(review): 0 presumably lets pypocketfft choose the number of threads
# itself (e.g. one per available core) -- confirm against pypocketfft docs.
nthreads = 0


def update_err(err, name, value):
    """Record ``value`` as ``err[name]`` if it is the worst error seen so far.

    ``err`` maps a category name to the largest error observed for it.  A
    new worst value is stored and printed as ``"<name>: <value>"``.

    NaN counts as worse than any number: a NaN error replaces a finite one,
    and once a category holds NaN it stays NaN.  A plain ``value > old``
    comparison is always False when either side is NaN, which would
    otherwise silently hide a transform that produced NaN output.

    Returns the (mutated) ``err`` dict.
    """
    if name in err:
        previous = err[name]
        if np.isnan(previous):
            # Already at the worst possible value; nothing can exceed it.
            return err
        if not (np.isnan(value) or value > previous):
            return err
    err[name] = value
    print("{}: {}".format(name, value))
    return err


def test(err):
    """Run one randomized batch of transform round trips and record errors.

    Draws a random complex array with 1-4 dimensions (every axis shorter
    than ``int((2**20)**(1/ndim))``) and a random, shuffled, non-empty
    subset of its axes.  Each transform pair below is applied forward and
    then backward (the second call with ``inorm=2``) along those axes.  The
    relative L2 error against the input is passed to ``update_err`` under a
    category key: "cmax"/"rmax"/"hmax" for double precision complex, real
    and Hartley round trips, "cmaxf"/"rmaxf"/"hmaxf" for single precision.
    Every call uses the module-level ``nthreads``.
    """
    ndim = np.random.randint(1, 5)
    limit = int((2**20)**(1./ndim))
    shape = np.random.randint(1, limit, ndim)
    order = np.arange(ndim)
    np.random.shuffle(order)
    axes = order[:np.random.randint(1, ndim+1)]
    # Passed to irfftn as lastsize: real length of the last transformed axis.
    lastsize = shape[axes[-1]]
    a = np.random.rand(*shape)-0.5 + 1j*np.random.rand(*shape)-0.5j
    kw = dict(axes=axes, nthreads=nthreads)

    # (category, round-trip thunk, reference thunk), run strictly in order;
    # for each entry the round trip is computed before its reference.
    checks = (
        ("cmax", lambda: ifftn(fftn(a, **kw), inorm=2, **kw),
         lambda: a),
        ("cmax", lambda: ifftn(fftn(a.real, **kw), inorm=2, **kw),
         lambda: a.real),
        ("cmax", lambda: fftn(ifftn(a.real, **kw), inorm=2, **kw),
         lambda: a.real),
        ("rmax", lambda: irfftn(rfftn(a.real, **kw), inorm=2,
                                lastsize=lastsize, **kw),
         lambda: a.real),
        ("cmaxf", lambda: ifftn(fftn(a.astype(np.complex64), **kw),
                                inorm=2, **kw),
         lambda: a.astype(np.complex64)),
        ("rmaxf", lambda: irfftn(rfftn(a.real.astype(np.float32), **kw),
                                 inorm=2, lastsize=lastsize, **kw),
         lambda: a.real.astype(np.float32)),
        ("hmax", lambda: pypocketfft.separable_hartley(
            pypocketfft.separable_hartley(a.real, **kw), inorm=2, **kw),
         lambda: a.real),
        ("hmax", lambda: pypocketfft.genuine_hartley(
            pypocketfft.genuine_hartley(a.real, **kw), inorm=2, **kw),
         lambda: a.real),
        ("hmaxf", lambda: pypocketfft.separable_hartley(
            pypocketfft.separable_hartley(a.real.astype(np.float32), **kw),
            inorm=2, **kw),
         lambda: a.real.astype(np.float32)),
        ("hmaxf", lambda: pypocketfft.genuine_hartley(
            pypocketfft.genuine_hartley(a.real.astype(np.float32), **kw),
            inorm=2, **kw),
         lambda: a.real.astype(np.float32)),
    )
    for category, round_trip, reference in checks:
        result = round_trip()
        err = update_err(err, category, _l2error(reference(), result))


if __name__ == "__main__":
    # Run randomized round trips forever (stop with Ctrl-C).  `err` keeps the
    # worst relative error seen so far per category; update_err prints each
    # new worst value.  Guarded so that importing this module does not start
    # an infinite loop.
    err = dict()
    while True:
        test(err)