Commit ba08fbe1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

merge master

parents 11c94b91 a220fe56
Pipeline #13349 passed with stage
in 5 minutes and 17 seconds
......@@ -28,17 +28,25 @@ __all__ = ['dependency_injector', 'nifty_configuration']
# Setup the dependency injector
dependency_injector = keepers.DependencyInjector(
[('mpi4py.MPI', 'MPI'),
('pyfftw', 'fftw'),
'pyHealpix',
'plotly'])
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
def _fft_module_checker(z):
if z == 'fftw_mpi':
return hasattr(dependency_injector.get('fftw'), 'FFTW_MPI')
if z == 'fftw':
return ('fftw' in dependency_injector)
if z == 'numpy':
return True
return False
# Initialize the variables
variable_fft_module = keepers.Variable(
'fft_module',
['fftw', 'numpy'],
lambda z: (('pyfftw' in dependency_injector)
if z == 'fftw' else True))
['fftw_mpi', 'fftw', 'numpy'],
_fft_module_checker)
def dtype_validator(dtype):
......
......@@ -65,7 +65,6 @@ class LineSearchStrongWolfe(LineSearch):
max_step_size=50, max_iterations=10,
max_zoom_iterations=10):
super(LineSearchStrongWolfe, self).__init__()
self.c1 = np.float(c1)
......@@ -308,7 +307,8 @@ class LineSearchStrongWolfe(LineSearch):
"""Estimating the minimum with cubic interpolation.
Finds the minimizer for a cubic polynomial that goes through the
points ( a,f(a) ), ( b,f(b) ), and ( c,f(c) ) with derivative at point a of fpa.
points ( a,f(a) ), ( b,f(b) ), and ( c,f(c) ) with derivative at point
a of fpa.
f(x) = A *(x-a)^3 + B*(x-a)^2 + C*(x-a) + D
If no minimizer can be found return None
......@@ -341,12 +341,12 @@ class LineSearchStrongWolfe(LineSearch):
C = fpa
db = b - a
dc = c - a
denom = (db * dc) ** 2 * (db - dc)
denom = db * db * dc * dc * (db - dc)
d1 = np.empty((2, 2))
d1[0, 0] = dc ** 2
d1[0, 1] = -db ** 2
d1[1, 0] = -dc ** 3
d1[1, 1] = db ** 3
d1[0, 0] = dc * dc
d1[0, 1] = -(db*db)
d1[1, 0] = -(dc*dc*dc)
d1[1, 1] = db*db*db
[A, B] = np.dot(d1, np.asarray([fb - fa - C * db,
fc - fa - C * dc]).flatten())
A /= denom
......
......@@ -63,9 +63,10 @@ class FFTOperator(LinearOperator):
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is
always available, but "fftw" offers higher performance and
parallelization. For sphere-related domains, only "pyHealpix" is
For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
always available (using pyfftw if available, else numpy.fft), and "mpi"
requires pyfftw and offers MPI parallelization.
For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional)
......
......@@ -25,7 +25,7 @@ import nifty.nifty_utilities as utilities
from keepers import Loggable
pyfftw = gdi.get('pyfftw')
fftw = gdi.get('fftw')
class Transform(Loggable, object):
......@@ -200,20 +200,21 @@ class Transform(Loggable, object):
raise NotImplementedError
class FFTW(Transform):
class MPIFFT(Transform):
"""
The pyfftw pendant of a fft object.
The MPI-parallel FFTW pendant of a fft object.
"""
def __init__(self, domain, codomain):
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError(
"The MPI FFTW module is needed but not available.")
super(FFTW, self).__init__(domain, codomain)
super(MPIFFT, self).__init__(domain, codomain)
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
# Enable caching
fftw.interfaces.cache.enable()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
......@@ -409,7 +410,7 @@ class FFTW(Transform):
def transform(self, val, axes, **kwargs):
"""
The pyfftw transform function.
The MPI-parallel FFTW transform function.
Parameters
----------
......@@ -467,8 +468,9 @@ class FFTW(Transform):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, axes, local_shape,
local_offset_Q, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
if not hasattr(fftw, 'FFTW_MPI'):
raise ImportError(
"The MPI FFTW module is needed but not available.")
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
......@@ -512,9 +514,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context,
**kwargs)
if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
self._fftw_interface = fftw.interfaces.numpy_fft.fftn
else:
self._fftw_interface = pyfftw.interfaces.numpy_fft.ifftn
self._fftw_interface = fftw.interfaces.numpy_fft.ifftn
@property
def fftw_interface(self):
......@@ -531,7 +533,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q,
fftw_context,
**kwargs)
self._plan = pyfftw.create_mpi_plan(
self._plan = fftw.create_mpi_plan(
input_shape=transform_shape,
input_dtype='complex128',
output_dtype='complex128',
......@@ -545,15 +547,26 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return self._plan
class NUMPYFFT(Transform):
class SerialFFT(Transform):
"""
The numpy fft pendant of a fft object.
"""
def __init__(self, domain, codomain, use_fftw):
super(SerialFFT, self).__init__(domain, codomain)
if use_fftw and (fftw is None):
raise ImportError(
"The serial FFTW module is needed but not available.")
self._use_fftw = use_fftw
# Enable caching
if self._use_fftw:
fftw.interfaces.cache.enable()
def transform(self, val, axes, **kwargs):
"""
The pyfftw transform function.
The scalar FFT transform function.
Parameters
----------
......@@ -572,6 +585,7 @@ class NUMPYFFT(Transform):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
......@@ -625,6 +639,14 @@ class NUMPYFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if self._use_fftw:
if self.codomain.harmonic:
result_val = fftw.interfaces.numpy_fft.fftn(
local_val, axes=axes)
else:
result_val = fftw.interfaces.numpy_fft.ifftn(
local_val, axes=axes)
else:
if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
else:
......
......@@ -18,7 +18,7 @@
import numpy as np
from transformation import Transformation
from rg_transforms import FFTW, NUMPYFFT
from rg_transforms import MPIFFT, SerialFFT
from nifty import RGSpace, nifty_configuration
......@@ -30,18 +30,16 @@ class RGRGTransformation(Transformation):
super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None:
if nifty_configuration['fft_module'] == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
elif nifty_configuration['fft_module'] == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
else:
raise ValueError('Unsupported default FFT module:' +
nifty_configuration['fft_module'])
else:
if module == 'fftw':
self._transform = FFTW(self.domain, self.codomain)
module = nifty_configuration['fft_module']
if module == 'fftw_mpi':
self._transform = MPIFFT(self.domain, self.codomain)
elif module == 'fftw':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=True)
elif module == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=False)
else:
raise ValueError('Unsupported FFT module:' + module)
......@@ -118,7 +116,7 @@ class RGRGTransformation(Transformation):
np.absolute(np.array(domain.shape) *
np.array(domain.distances) *
np.array(codomain.distances) - 1) <
10**-7):
1e-7):
raise AttributeError("The grid-distances of domain and codomain "
"do not match.")
......
......@@ -53,10 +53,13 @@ class DirectSmoothingOperator(SmoothingOperator):
wgt = []
expfac = 1. / (2. * sigma*sigma)
for i in range(x.size):
if nval[i]>0:
t = x[ibegin[i]:ibegin[i]+nval[i]]-x[i]
t = np.exp(-t*t*expfac)
t *= 1./np.sum(t)
wgt.append(t)
else:
wgt.append(np.array([]))
return ibegin, nval, wgt
......@@ -146,7 +149,7 @@ class DirectSmoothingOperator(SmoothingOperator):
#MR FIXME: this causes calls of log(0.) which should probably be avoided
if self.log_distances:
np.log(distance_array, out=distance_array)
np.log(np.maximum(distance_array,1e-15), out=distance_array)
# collect the local data + ghost cells
local_data_Q = False
......
......@@ -93,11 +93,11 @@ class HPSpace(Space):
@property
def shape(self):
return (np.int(12 * self.nside ** 2),)
return (np.int(12 * self.nside * self.nside),)
@property
def dim(self):
return np.int(12 * self.nside ** 2)
return np.int(12 * self.nside * self.nside)
@property
def total_volume(self):
......@@ -108,7 +108,7 @@ class HPSpace(Space):
def weight(self, x, power=1, axes=None, inplace=False):
weight = ((4 * np.pi) / (12 * self.nside**2)) ** np.float(power)
weight = ((4*np.pi) / (12*self.nside*self.nside)) ** np.float(power)
if inplace:
x *= weight
......
......@@ -130,7 +130,7 @@ class LMSpace(Space):
# dim = (((2*(l+1)-1)+1)**2/4 - 2 * (l-m)(l-m+1)/2
# dim = np.int((l+1)**2 - (l-m)*(l-m+1.))
# We fix l == m
return np.int((l+1)**2)
return np.int((l+1)*(l+1))
@property
def total_volume(self):
......@@ -166,7 +166,7 @@ class LMSpace(Space):
def get_fft_smoothing_kernel_function(self, sigma):
# FIXME why x(x+1) ? add reference to paper!
return lambda x: np.exp(-0.5 * x * (x + 1) * sigma**2)
return lambda x: np.exp(-0.5 * x * (x + 1) * sigma*sigma)
# ---Added properties and methods---
......
......@@ -267,14 +267,16 @@ class RGSpace(Space):
cords = np.ogrid[inds]
dists = ((cords[0] - shape[0]//2)*dk[0])**2
dists = (cords[0] - shape[0]//2)*dk[0]
dists *= dists
# apply zerocenterQ shift
if not self.zerocenter[0]:
dists = np.fft.ifftshift(dists)
# only save the individual slice
dists = dists[slice_of_first_dimension]
for ii in range(1, len(shape)):
temp = ((cords[ii] - shape[ii] // 2) * dk[ii])**2
temp = (cords[ii] - shape[ii] // 2) * dk[ii]
temp *= temp
if not self.zerocenter[ii]:
temp = np.fft.ifftshift(temp)
dists = dists + temp
......@@ -282,7 +284,7 @@ class RGSpace(Space):
return dists
def get_fft_smoothing_kernel_function(self, sigma):
return lambda x: np.exp(-0.5 * np.pi**2 * x**2 * sigma**2)
return lambda x: np.exp(-0.5 * np.pi*np.pi * x*x * sigma*sigma)
# ---Added properties and methods---
......
......@@ -16,9 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from nose_parameterized import parameterized
from parameterized import parameterized
from nifty import RGSpace, LMSpace, HPSpace, GLSpace, PowerSpace
from nifty.config import dependency_injector as di
from nifty.config import dependency_injector as gdi
def custom_name_func(testcase_func, param_num, param):
......@@ -36,7 +36,7 @@ def expand(*args, **kwargs):
def generate_spaces():
spaces = [RGSpace(4), PowerSpace(RGSpace((4, 4), harmonic=True)),
LMSpace(5), HPSpace(4)]
if 'pyHealpix' in di:
if 'pyHealpix' in gdi:
spaces.append(GLSpace(4))
return spaces
......
from nifty import *
#This tests if it is possible to import all of nifties methods. Experience shows this is not always possible.
# This tests if it is possible to import all of Nifty's methods.
# Experience shows this is not always possible.
pass
......@@ -20,7 +20,7 @@ import unittest
import numpy as np
from numpy.testing import assert_equal,\
assert_allclose
from nifty.config import dependency_injector as di
from nifty.config import dependency_injector as gdi
from nifty import Field,\
RGSpace,\
LMSpace,\
......@@ -62,11 +62,15 @@ class FFTOperatorTests(unittest.TestCase):
res = foo.get_distance_array('not')
assert_equal(res[zc1 * (dim1 // 2), zc2 * (dim2 // 2)], 0.)
@expand(product(["numpy", "fftw"], [10, 11], [False, True], [False, True],
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [False, True], [False, True],
[0.1, 1, 3.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft1D(self, module, dim1, zc1, zc2, d, itp):
if module == "fftw" and "pyfftw" not in di:
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d)
......@@ -78,12 +82,16 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw"], [10, 11], [9, 12], [False, True],
@expand(product(["numpy", "fftw", "fftw_mpi"],
[10, 11], [9, 12], [False, True],
[False, True], [False, True], [False, True], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_fft2D(self, module, dim1, dim2, zc1, zc2, zc3, zc4, d1, d2, itp):
if module == "fftw" and "pyfftw" not in di:
if module == "fftw_mpi":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], zerocenter=[zc1, zc2], distances=[d1, d2])
......@@ -99,7 +107,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([0, 3, 6, 11, 30],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_sht(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......@@ -113,7 +121,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_sht2(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
......@@ -126,7 +134,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......@@ -142,7 +150,7 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.complex128, np.float32, np.complex64]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
......
......@@ -24,7 +24,7 @@ from numpy.testing import assert_, assert_equal, assert_raises,\
assert_almost_equal
from nose.plugins.skip import SkipTest
from nifty import GLSpace
from nifty.config import dependency_injector as di
from nifty.config import dependency_injector as gdi
from test.common import expand
# [nlat, nlon, expected]
......@@ -99,7 +99,7 @@ class GLSpaceFunctionalityTests(unittest.TestCase):
except ImportError:
raise SkipTest
if 'pyHealpix' not in di:
if 'pyHealpix' not in gdi:
raise SkipTest
else:
g = GLSpace(2)
......
......@@ -28,19 +28,20 @@ from nifty import PowerSpace, RGSpace, Space, LMSpace
from types import NoneType
from test.common import expand
from itertools import product, chain
#needed to check wether fftw is available
from d2o.config import dependency_injector as gdi
# needed to check wether fftw is available
from nifty import dependency_injector as gdi
from nose.plugins.skip import SkipTest
HARMONIC_SPACES = [RGSpace((8,), harmonic=True),
RGSpace((7,), harmonic=True,zerocenter=True),
RGSpace((8,), harmonic=True,zerocenter=True),
RGSpace((7,8), harmonic=True),
RGSpace((7,8), harmonic=True, zerocenter=True),
RGSpace((6,6), harmonic=True, zerocenter=True),
RGSpace((7,5), harmonic=True, zerocenter=True),
RGSpace((5,5), harmonic=True),
RGSpace((4,5,7), harmonic=True),
RGSpace((4,5,7), harmonic=True, zerocenter=True),
RGSpace((7,), harmonic=True, zerocenter=True),
RGSpace((8,), harmonic=True, zerocenter=True),
RGSpace((7, 8), harmonic=True),
RGSpace((7, 8), harmonic=True, zerocenter=True),
RGSpace((6, 6), harmonic=True, zerocenter=True),
RGSpace((7, 5), harmonic=True, zerocenter=True),
RGSpace((5, 5), harmonic=True),
RGSpace((4, 5, 7), harmonic=True),
RGSpace((4, 5, 7), harmonic=True, zerocenter=True),
LMSpace(6),
LMSpace(9)]
......@@ -49,9 +50,14 @@ HARMONIC_SPACES = [RGSpace((8,), harmonic=True),
#binning parameters
_maybe_fftw = ["fftw"] if ('pyfftw' in gdi) else []
CONSISTENCY_CONFIGS_IMPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [None], [None, 3,4], [True, False])
CONSISTENCY_CONFIGS_EXPLICIT = product(HARMONIC_SPACES, ["not", "equal"] + _maybe_fftw, [[0.,1.3]],[None],[False])
CONSISTENCY_CONFIGS = chain(CONSISTENCY_CONFIGS_IMPLICIT, CONSISTENCY_CONFIGS_EXPLICIT)
CONSISTENCY_CONFIGS_IMPLICIT = product(HARMONIC_SPACES,
["not", "equal", "fftw"],
[None], [None, 3, 4], [True, False])
CONSISTENCY_CONFIGS_EXPLICIT = product(HARMONIC_SPACES,
["not", "equal", "fftw"],
[[0., 1.3]], [None], [False])
CONSISTENCY_CONFIGS = chain(CONSISTENCY_CONFIGS_IMPLICIT,
CONSISTENCY_CONFIGS_EXPLICIT)
# [harmonic_partner, distribution_strategy,
# logarithmic, nbin, binbounds, expected]
......@@ -121,14 +127,20 @@ class PowerSpaceInterfaceTest(unittest.TestCase):
p = PowerSpace(r)
assert_(isinstance(getattr(p, attribute), expected_type))
class PowerSpaceConsistencyCheck(unittest.TestCase):
@expand(CONSISTENCY_CONFIGS)
def test_rhopindexConsistency(self, harmonic_partner, distribution_strategy,
binbounds, nbin,logarithmic):
if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
print (gdi.get('fftw'), "blub \n\n\n")
raise SkipTest
p = PowerSpace(harmonic_partner=harmonic_partner,
distribution_strategy=distribution_strategy,
logarithmic=logarithmic, nbin=nbin,
binbounds=binbounds)
assert_equal(p.pindex.flatten().bincount(), p.rho,
err_msg='rho is not equal to pindex degeneracy')
......@@ -136,6 +148,10 @@ class PowerSpaceFunctionalityTest(unittest.TestCase):
@expand(CONSTRUCTOR_CONFIGS)
def test_constructor(self, harmonic_partner, distribution_strategy,
logarithmic, nbin, binbounds, expected):
if distribution_strategy == "fftw":
if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
raise SkipTest
raise SkipTest
if 'error' in expected:
with assert_raises(expected['error']):
PowerSpace(harmonic_partner=harmonic_partner,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment