Commit 2cc75057 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch '16-add-operator-tests' into 'NIFTy_5'

Resolve "Create consistency checks for all linear operators and gradient consistency checks for all energies in library"

See merge request ift/nifty-dev!60
parents fcbd1ea9 9a7e60bf
......@@ -42,6 +42,7 @@ from .operators.laplace_operator import LaplaceOperator
from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator
from .operators.multi_adaptor import MultiAdaptor
from .operators.null_operator import NullOperator
from .operators.power_distributor import PowerDistributor
from .operators.qht_operator import QHTOperator
from .operators.sampling_enabler import SamplingEnabler
......
......@@ -158,11 +158,11 @@ class data_object(object):
def prod(self, axis=None):
return self._contraction_helper("prod", MPI.PROD, axis)
def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis)
# def min(self, axis=None):
# return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis)
# def max(self, axis=None):
# return self._contraction_helper("max", MPI.MAX, axis)
def mean(self, axis=None):
if axis is None:
......@@ -348,6 +348,12 @@ def np_allreduce_min(arr):
return res
def np_allreduce_max(arr):
res = np.empty_like(arr)
_comm.Allreduce(arr, res, MPI.MAX)
return res
def distaxis(arr):
return arr._distaxis
......
......@@ -75,6 +75,10 @@ def np_allreduce_min(arr):
return arr
def np_allreduce_max(arr):
return arr
def distaxis(arr):
return -1
......
......@@ -33,6 +33,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"empty", "zeros", "ones", "empty_like", "vdot", "exp",
"log", "tanh", "sqrt", "from_object", "from_random",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy",
"lock", "locked"]
......@@ -481,39 +481,39 @@ class Field(object):
def any(self, spaces=None):
return self._contraction_helper('any', spaces)
def min(self, spaces=None):
"""Determines the minimum over the sub-domains given by `spaces`.
Parameters
----------
spaces : None, int or tuple of int (default: None)
The operation is only carried out over the sub-domains in this
tuple. If None, it is carried out over all sub-domains.
Returns
-------
Field or scalar
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
"""
return self._contraction_helper('min', spaces)
def max(self, spaces=None):
"""Determines the maximum over the sub-domains given by `spaces`.
Parameters
----------
spaces : None, int or tuple of int (default: None)
The operation is only carried out over the sub-domains in this
tuple. If None, it is carried out over all sub-domains.
Returns
-------
Field or scalar
The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field.
"""
return self._contraction_helper('max', spaces)
# def min(self, spaces=None):
# """Determines the minimum over the sub-domains given by `spaces`.
#
# Parameters
# ----------
# spaces : None, int or tuple of int (default: None)
# The operation is only carried out over the sub-domains in this
# tuple. If None, it is carried out over all sub-domains.
#
# Returns
# -------
# Field or scalar
# The result of the operation. If it is carried out over the entire
# domain, this is a scalar, otherwise a Field.
# """
# return self._contraction_helper('min', spaces)
#
# def max(self, spaces=None):
# """Determines the maximum over the sub-domains given by `spaces`.
#
# Parameters
# ----------
# spaces : None, int or tuple of int (default: None)
# The operation is only carried out over the sub-domains in this
# tuple. If None, it is carried out over all sub-domains.
#
# Returns
# -------
# Field or scalar
# The result of the operation. If it is carried out over the entire
# domain, this is a scalar, otherwise a Field.
# """
# return self._contraction_helper('max', spaces)
def mean(self, spaces=None):
"""Determines the mean over the sub-domains given by `spaces`.
......
......@@ -41,9 +41,9 @@ class DOFDistributor(LinearOperator):
----------
dofdex: Field of integers
An integer Field on exactly one Space establishing the association
between the locations in the operators target and the underlying
between the locations in the operator's target and the underlying
degrees of freedom in its domain.
It has to start at 0 and it increases monotonicly, no empty bins are
It has to start at 0 and it increases monotonically, no empty bins are
allowed.
target: Domain, tuple of Domain, or DomainTuple, optional
The target of the operator. If not specified, the domain of the dofdex
......@@ -70,12 +70,17 @@ class DOFDistributor(LinearOperator):
if partner != dofdex.domain[0]:
raise ValueError("incorrect dofdex domain")
nbin = dofdex.max()
ldat = dofdex.local_data
if ldat.size==0: # can happen for weird configurations
nbin = 0
else:
nbin = ldat.max()
nbin = dobj.np_allreduce_max(np.array(nbin))[()] + 1
if partner.scalar_dvol is not None:
wgt = np.bincount(dofdex.local_data.ravel(), minlength=nbin)
wgt = wgt*partner.scalar_dvol
else:
dvol = dobj.local_data(partner.dvol)
dvol = Field.from_global_data(partner, partner.dvol).local_data
wgt = np.bincount(dofdex.local_data.ravel(),
minlength=nbin, weights=dvol)
# The explicit conversion to float64 is necessary because bincount
......
......@@ -76,7 +76,7 @@ class SlopeOperator(LinearOperator):
return Field.from_global_data(self.target, res)
# Adjoint times
res = np.zeros(self.domain[0].shape)
res = np.zeros(self.domain[0].shape, dtype=x.dtype)
xglob = x.to_global_data()
res[-1] = np.sum(xglob) * self._sigmas[-1]
for i in range(self.ndim):
......
......@@ -154,8 +154,8 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f1.local_data, +f1.local_data)
assert_equal(f1.sum(), f1.sum(0))
f1 = ift.from_global_data(s1, np.arange(10))
assert_equal(f1.min(), 0)
assert_equal(f1.max(), 9)
# assert_equal(f1.min(), 0)
# assert_equal(f1.max(), 9)
assert_equal(f1.prod(), 0)
def test_weight(self):
......
......@@ -35,6 +35,93 @@ _pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))]
class Consistency_Tests(unittest.TestCase):
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testOperatorCombinations(self, sp, dtype):
a = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
b = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
op = ift.SandwichOperator.make(a, b)
ift.extra.consistency_check(op, dtype, dtype)
op = a*b
ift.extra.consistency_check(op, dtype, dtype)
op = a+b
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
harmonic=True), 3, 0),
(ift.LMSpace(4), 10, 0)],
[np.float64, np.complex128]))
def testSlopeOperator(self, args, dtype):
tmp = ift.ExpTransform(ift.PowerSpace(args[0]), args[1], args[2])
tgt = tmp.domain[0]
sig = np.array([0.3, 0.13])
dom = ift.UnstructuredDomain(2)
op = ift.SlopeOperator(dom, tgt, sig)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testSelectionOperator(self, sp1, sp2, dtype):
mdom = ift.MultiDomain.make({'a':sp1, 'b':sp2})
op = ift.SelectionOperator(mdom, 'a')
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testOperatorAdaptor(self, sp, dtype):
op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
ift.extra.consistency_check(op.adjoint, dtype, dtype)
ift.extra.consistency_check(op.inverse, dtype, dtype)
ift.extra.consistency_check(op.inverse.adjoint, dtype, dtype)
ift.extra.consistency_check(op.adjoint.inverse, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testNullOperator(self, sp1, sp2, dtype):
op = ift.NullOperator(sp1, sp2)
ift.extra.consistency_check(op, dtype, dtype)
mdom1 = ift.MultiDomain.make({'a':sp1})
mdom2 = ift.MultiDomain.make({'b':sp2})
op = ift.NullOperator(mdom1, mdom2)
ift.extra.consistency_check(op, dtype, dtype)
op = ift.NullOperator(sp1, mdom2)
ift.extra.consistency_check(op, dtype, dtype)
op = ift.NullOperator(mdom1, sp2)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces
+ _pow_spaces,
[np.float64, np.complex128]))
def testMultiAdaptor(self, sp, dtype):
mdom = ift.MultiDomain.make({'a':sp})
op = ift.MultiAdaptor(mdom)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_p_RG_spaces,
[np.float64, np.complex128]))
def testHarmonicSmoothingOperator(self, sp, dtype):
op = ift.HarmonicSmoothingOperator(sp, 0.1)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces
+ _pow_spaces,
[np.float64, np.complex128]))
def testDOFDistributor(self, sp, dtype):
#TODO: Test for DomainTuple
if sp.size < 4:
return
dofdex = np.arange(sp.size).reshape(sp.shape)%3
dofdex = ift.Field.from_global_data(sp, dofdex)
op = ift.DOFDistributor(dofdex)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces, [np.float64, np.complex128]))
def testPPO(self, sp, dtype):
op = ift.PowerDistributor(target=sp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment