Commit 0cb56087 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'ducc_vdot' into 'NIFTy_8'

Use ducc's vdot if available (more accurate)

See merge request !653
parents 7f4fe0b7 07f97c95
Pipeline #105182 passed with stages
in 30 minutes and 38 seconds
......@@ -7,6 +7,14 @@ Jax interface
The interfaces `ift.JaxOperator` and `ift.JaxLikelihoodEnergyOperator` provide
an interface to jax.
Interface change for nthreads
-----------------------------
The number of threads used for FFTs, and for ducc in general, used to be set
via `ift.fft.set_nthreads(n)`. This function has been moved to
`ift.set_nthreads(n)`. Similarly, `ift.fft.nthreads()` is now `ift.nthreads()`.
Changes since NIFTy 6
=====================
......
......@@ -108,5 +108,7 @@ from .operator_spectrum import operator_spectrum
from .operator_tree_optimiser import optimise_operator
from .ducc_dispatch import set_nthreads, nthreads
# We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty8 import *"; that would swamp the global namespace.
......@@ -11,11 +11,12 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
_nthreads = 1
......@@ -30,6 +31,7 @@ def set_nthreads(nthr):
try:
import ducc0.fft as my_fft
import ducc0.misc
def fftn(a, axes=None):
......@@ -44,6 +46,14 @@ try:
def hartley(a, axes=None):
    """Apply ducc0's ``genuine_hartley`` transform to `a` along `axes`.

    The thread count is taken from the module-level ``_nthreads`` setting.
    """
    # Negative thread settings are clamped to 0 before being passed on.
    n_workers = max(_nthreads, 0)
    return my_fft.genuine_hartley(a, axes=axes, nthreads=n_workers)
def vdot(a, b):
    """Compute the scalar product of `a` and `b` with ``ducc0.misc.vdot``.

    Integer-valued NumPy arrays are promoted to ``np.float64`` before the
    call. Previously only ``int64`` arrays were promoted; every NumPy integer
    dtype (e.g. ``int32``, unsigned types) is now treated the same way.

    Parameters
    ----------
    a, b : scalar or numpy.ndarray
        The operands of the scalar product.

    Returns
    -------
    The value returned by ``ducc0.misc.vdot`` for the (possibly promoted)
    operands.
    """
    def _promote_integers(arg):
        # NOTE(review): presumably ducc0.misc.vdot only handles floating-point
        # and complex dtypes, which is why integer arrays are converted here
        # -- confirm against the ducc0 documentation.
        if isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer):
            return arg.astype(np.float64)
        return arg

    return ducc0.misc.vdot(_promote_integers(a), _promote_integers(b))
except ImportError:
import scipy.fft
......@@ -59,3 +69,11 @@ except ImportError:
def hartley(a, axes=None):
    """Fallback Hartley transform built on ``scipy.fft.fftn``.

    Runs a complex FFT of `a` along `axes`, using the module-level
    ``_nthreads`` as the number of workers, and returns the sum of the real
    and imaginary parts of the result.
    """
    spectrum = scipy.fft.fftn(a, axes=axes, workers=_nthreads)
    return spectrum.real + spectrum.imag
def vdot(a, b):
    """Fallback scalar product that delegates to ``np.vdot``.

    A warning is logged whenever either operand is a NumPy array of dtype
    ``np.float32``; the result of ``np.vdot(a, b)`` is returned in all cases.
    """
    from .logger import logger

    def _is_float32_array(obj):
        return isinstance(obj, np.ndarray) and obj.dtype == np.float32

    if _is_float32_array(a) or _is_float32_array(b):
        logger.warning("Calling np.vdot in single precision may lead to inaccurate results")
    return np.vdot(a, b)
......@@ -22,6 +22,7 @@ import numpy as np
from . import utilities
from .domain_tuple import DomainTuple
from .operators.operator import Operator
from .ducc_dispatch import vdot
class Field(Operator):
......@@ -316,7 +317,7 @@ class Field(Operator):
spaces = utilities.parse_spaces(spaces, ndom)
if len(spaces) == ndom:
return Field.scalar(np.array(np.vdot(self._val, x._val)))
return Field.scalar(np.array(vdot(self._val, x._val)))
# If we arrive here, we have to do a partial dot product.
# For the moment, do this the explicit, non-optimized way
return (self.conjugate()*x).sum(spaces=spaces)
......@@ -341,7 +342,7 @@ class Field(Operator):
if x._domain != self._domain:
raise ValueError("Domain mismatch")
return np.vdot(self._val, x._val)
return vdot(self._val, x._val)
def norm(self, ord=2):
"""Computes the L2-norm of the field values.
......
......@@ -21,7 +21,7 @@ from scipy.constants import speed_of_light
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..fft import nthreads
from ..ducc_dispatch import nthreads
from ..operators.linear_operator import LinearOperator
from ..sugar import makeDomain, makeField
......
......@@ -17,12 +17,13 @@
import numpy as np
from .. import fft, utilities
from .. import utilities
from ..domain_tuple import DomainTuple
from ..domains.gl_space import GLSpace
from ..domains.lm_space import LMSpace
from ..domains.rg_space import RGSpace
from ..field import Field
from ..ducc_dispatch import fftn, ifftn, hartley
from .diagonal_operator import DiagonalOperator
from .linear_operator import LinearOperator
from .scaling_operator import ScalingOperator
......@@ -74,10 +75,10 @@ class FFTOperator(LinearOperator):
self._check_input(x, mode)
ncells = x.domain[self._space].size
if x.domain[self._space].harmonic: # harmonic -> position
func = fft.ifftn
func = ifftn
fct = ncells
else:
func = fft.fftn
func = fftn
fct = 1.
axes = x.domain.axes[self._space]
tdom = self._tgt(mode)
......@@ -148,7 +149,7 @@ class HartleyOperator(LinearOperator):
def _apply_cartesian(self, x, mode):
axes = x.domain.axes[self._space]
tdom = self._tgt(mode)
tmp = fft.hartley(x.val, axes=axes)
tmp = hartley(x.val, axes=axes)
Tval = Field(tdom, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct = self._domain[self._space].scalar_dvol
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
import numpy as np
from ducc0.misc import vdot
import pytest
from .common import setup_function, teardown_function
dt = np.float32
def _check(a, b):
    """Assert that ``np.vdot`` and ``ducc0.misc.vdot`` agree on `a` and `b`.

    The comparison uses the difference of both results relative to their
    mean and requires it to be below 1e-6.
    """
    numpy_result = np.vdot(a, b)
    ducc_result = vdot(a, b)
    difference = numpy_result - ducc_result
    total = numpy_result + ducc_result
    rel_error = np.abs(difference/total*2)
    assert rel_error < 1e-6
# When this is fixed in numpy, the warning in src/ducc_dispatch.py is no longer
# necessary
@pytest.mark.xfail(reason="np.vdot inaccurate for single precision", strict=True)
def test_vdot():
    """Compare both vdot implementations on a long constant vector.

    The vector holds one million entries of value 100 in the module-level
    dtype ``dt``; the test is marked as a strict expected failure.
    """
    values = np.full((1000000,), 100, dtype=dt)
    _check(values, values)
@pytest.mark.xfail(reason="np.vdot inaccurate for single precision", strict=True)
def test_vdot_extreme():
    """Compare both vdot implementations on a cancellation-prone input.

    The two large products (+1e16 and -1e16) cancel, leaving only the small
    middle term; the test is marked as a strict expected failure.
    """
    left = np.array([1e8, 1, -1e8], dtype=dt)
    right = np.array([1e8, 1, 1e8], dtype=dt)
    _check(left, right)
......@@ -62,8 +62,8 @@ def test_fft1D(d, dtype, op):
@pmp('d2', [0.4, 1, 2.7])
@pmp('nthreads', [-1, 1, 2, 3, 4])
def test_fft2D(dim1, dim2, d1, d2, dtype, op, nthreads):
ift.fft.set_nthreads(nthreads)
ift.myassert(ift.fft.nthreads() == nthreads)
ift.set_nthreads(nthreads)
ift.myassert(ift.nthreads() == nthreads)
tol = _get_rtol(dtype)
a = ift.RGSpace([dim1, dim2], distances=[d1, d2])
b = ift.RGSpace(
......@@ -80,7 +80,7 @@ def test_fft2D(dim1, dim2, d1, d2, dtype, op, nthreads):
inp = ift.Field.from_random(domain=a, random_type='normal', dtype=dtype, std=7, mean=3)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
ift.fft.set_nthreads(1)
ift.set_nthreads(1)
@pmp('index', [0, 1, 2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment