Commit 2cc75057 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch '16-add-operator-tests' into 'NIFTy_5'

Resolve "Create consistency checks for all linear operators and gradient consistency checks for all energies in library"

See merge request ift/nifty-dev!60
parents fcbd1ea9 9a7e60bf
...@@ -42,6 +42,7 @@ from .operators.laplace_operator import LaplaceOperator ...@@ -42,6 +42,7 @@ from .operators.laplace_operator import LaplaceOperator
from .operators.linear_operator import LinearOperator from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator from .operators.mask_operator import MaskOperator
from .operators.multi_adaptor import MultiAdaptor from .operators.multi_adaptor import MultiAdaptor
from .operators.null_operator import NullOperator
from .operators.power_distributor import PowerDistributor from .operators.power_distributor import PowerDistributor
from .operators.qht_operator import QHTOperator from .operators.qht_operator import QHTOperator
from .operators.sampling_enabler import SamplingEnabler from .operators.sampling_enabler import SamplingEnabler
......
...@@ -158,11 +158,11 @@ class data_object(object): ...@@ -158,11 +158,11 @@ class data_object(object):
def prod(self, axis=None): def prod(self, axis=None):
return self._contraction_helper("prod", MPI.PROD, axis) return self._contraction_helper("prod", MPI.PROD, axis)
def min(self, axis=None): # def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis) # return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None): # def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis) # return self._contraction_helper("max", MPI.MAX, axis)
def mean(self, axis=None): def mean(self, axis=None):
if axis is None: if axis is None:
...@@ -348,6 +348,12 @@ def np_allreduce_min(arr): ...@@ -348,6 +348,12 @@ def np_allreduce_min(arr):
return res return res
def np_allreduce_max(arr):
res = np.empty_like(arr)
_comm.Allreduce(arr, res, MPI.MAX)
return res
def distaxis(arr): def distaxis(arr):
return arr._distaxis return arr._distaxis
......
...@@ -75,6 +75,10 @@ def np_allreduce_min(arr): ...@@ -75,6 +75,10 @@ def np_allreduce_min(arr):
return arr return arr
def np_allreduce_max(arr):
return arr
def distaxis(arr): def distaxis(arr):
return -1 return -1
......
...@@ -33,6 +33,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -33,6 +33,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"empty", "zeros", "ones", "empty_like", "vdot", "exp", "empty", "zeros", "ones", "empty_like", "vdot", "exp",
"log", "tanh", "sqrt", "from_object", "from_random", "log", "tanh", "sqrt", "from_object", "from_random",
"local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum", "local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
"np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data", "distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "redistribute", "default_distaxis", "is_numpy",
"lock", "locked"] "lock", "locked"]
...@@ -481,39 +481,39 @@ class Field(object): ...@@ -481,39 +481,39 @@ class Field(object):
def any(self, spaces=None): def any(self, spaces=None):
return self._contraction_helper('any', spaces) return self._contraction_helper('any', spaces)
def min(self, spaces=None): # def min(self, spaces=None):
"""Determines the minimum over the sub-domains given by `spaces`. # """Determines the minimum over the sub-domains given by `spaces`.
#
Parameters # Parameters
---------- # ----------
spaces : None, int or tuple of int (default: None) # spaces : None, int or tuple of int (default: None)
The operation is only carried out over the sub-domains in this # The operation is only carried out over the sub-domains in this
tuple. If None, it is carried out over all sub-domains. # tuple. If None, it is carried out over all sub-domains.
#
Returns # Returns
------- # -------
Field or scalar # Field or scalar
The result of the operation. If it is carried out over the entire # The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field. # domain, this is a scalar, otherwise a Field.
""" # """
return self._contraction_helper('min', spaces) # return self._contraction_helper('min', spaces)
#
def max(self, spaces=None): # def max(self, spaces=None):
"""Determines the maximum over the sub-domains given by `spaces`. # """Determines the maximum over the sub-domains given by `spaces`.
#
Parameters # Parameters
---------- # ----------
spaces : None, int or tuple of int (default: None) # spaces : None, int or tuple of int (default: None)
The operation is only carried out over the sub-domains in this # The operation is only carried out over the sub-domains in this
tuple. If None, it is carried out over all sub-domains. # tuple. If None, it is carried out over all sub-domains.
#
Returns # Returns
------- # -------
Field or scalar # Field or scalar
The result of the operation. If it is carried out over the entire # The result of the operation. If it is carried out over the entire
domain, this is a scalar, otherwise a Field. # domain, this is a scalar, otherwise a Field.
""" # """
return self._contraction_helper('max', spaces) # return self._contraction_helper('max', spaces)
def mean(self, spaces=None): def mean(self, spaces=None):
"""Determines the mean over the sub-domains given by `spaces`. """Determines the mean over the sub-domains given by `spaces`.
......
...@@ -41,9 +41,9 @@ class DOFDistributor(LinearOperator): ...@@ -41,9 +41,9 @@ class DOFDistributor(LinearOperator):
---------- ----------
dofdex: Field of integers dofdex: Field of integers
An integer Field on exactly one Space establishing the association An integer Field on exactly one Space establishing the association
between the locations in the operators target and the underlying between the locations in the operator's target and the underlying
degrees of freedom in its domain. degrees of freedom in its domain.
It has to start at 0 and it increases monotonicly, no empty bins are It has to start at 0 and it increases monotonically, no empty bins are
allowed. allowed.
target: Domain, tuple of Domain, or DomainTuple, optional target: Domain, tuple of Domain, or DomainTuple, optional
The target of the operator. If not specified, the domain of the dofdex The target of the operator. If not specified, the domain of the dofdex
...@@ -70,12 +70,17 @@ class DOFDistributor(LinearOperator): ...@@ -70,12 +70,17 @@ class DOFDistributor(LinearOperator):
if partner != dofdex.domain[0]: if partner != dofdex.domain[0]:
raise ValueError("incorrect dofdex domain") raise ValueError("incorrect dofdex domain")
nbin = dofdex.max() ldat = dofdex.local_data
if ldat.size==0: # can happen for weird configurations
nbin = 0
else:
nbin = ldat.max()
nbin = dobj.np_allreduce_max(np.array(nbin))[()] + 1
if partner.scalar_dvol is not None: if partner.scalar_dvol is not None:
wgt = np.bincount(dofdex.local_data.ravel(), minlength=nbin) wgt = np.bincount(dofdex.local_data.ravel(), minlength=nbin)
wgt = wgt*partner.scalar_dvol wgt = wgt*partner.scalar_dvol
else: else:
dvol = dobj.local_data(partner.dvol) dvol = Field.from_global_data(partner, partner.dvol).local_data
wgt = np.bincount(dofdex.local_data.ravel(), wgt = np.bincount(dofdex.local_data.ravel(),
minlength=nbin, weights=dvol) minlength=nbin, weights=dvol)
# The explicit conversion to float64 is necessary because bincount # The explicit conversion to float64 is necessary because bincount
......
...@@ -76,7 +76,7 @@ class SlopeOperator(LinearOperator): ...@@ -76,7 +76,7 @@ class SlopeOperator(LinearOperator):
return Field.from_global_data(self.target, res) return Field.from_global_data(self.target, res)
# Adjoint times # Adjoint times
res = np.zeros(self.domain[0].shape) res = np.zeros(self.domain[0].shape, dtype=x.dtype)
xglob = x.to_global_data() xglob = x.to_global_data()
res[-1] = np.sum(xglob) * self._sigmas[-1] res[-1] = np.sum(xglob) * self._sigmas[-1]
for i in range(self.ndim): for i in range(self.ndim):
......
...@@ -154,8 +154,8 @@ class Test_Functionality(unittest.TestCase): ...@@ -154,8 +154,8 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f1.local_data, +f1.local_data) assert_equal(f1.local_data, +f1.local_data)
assert_equal(f1.sum(), f1.sum(0)) assert_equal(f1.sum(), f1.sum(0))
f1 = ift.from_global_data(s1, np.arange(10)) f1 = ift.from_global_data(s1, np.arange(10))
assert_equal(f1.min(), 0) # assert_equal(f1.min(), 0)
assert_equal(f1.max(), 9) # assert_equal(f1.max(), 9)
assert_equal(f1.prod(), 0) assert_equal(f1.prod(), 0)
def test_weight(self): def test_weight(self):
......
...@@ -35,6 +35,93 @@ _pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))] ...@@ -35,6 +35,93 @@ _pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))]
class Consistency_Tests(unittest.TestCase): class Consistency_Tests(unittest.TestCase):
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testOperatorCombinations(self, sp, dtype):
a = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
b = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
op = ift.SandwichOperator.make(a, b)
ift.extra.consistency_check(op, dtype, dtype)
op = a*b
ift.extra.consistency_check(op, dtype, dtype)
op = a+b
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
harmonic=True), 3, 0),
(ift.LMSpace(4), 10, 0)],
[np.float64, np.complex128]))
def testSlopeOperator(self, args, dtype):
tmp = ift.ExpTransform(ift.PowerSpace(args[0]), args[1], args[2])
tgt = tmp.domain[0]
sig = np.array([0.3, 0.13])
dom = ift.UnstructuredDomain(2)
op = ift.SlopeOperator(dom, tgt, sig)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testSelectionOperator(self, sp1, sp2, dtype):
mdom = ift.MultiDomain.make({'a':sp1, 'b':sp2})
op = ift.SelectionOperator(mdom, 'a')
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testOperatorAdaptor(self, sp, dtype):
op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
ift.extra.consistency_check(op.adjoint, dtype, dtype)
ift.extra.consistency_check(op.inverse, dtype, dtype)
ift.extra.consistency_check(op.inverse.adjoint, dtype, dtype)
ift.extra.consistency_check(op.adjoint.inverse, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testNullOperator(self, sp1, sp2, dtype):
op = ift.NullOperator(sp1, sp2)
ift.extra.consistency_check(op, dtype, dtype)
mdom1 = ift.MultiDomain.make({'a':sp1})
mdom2 = ift.MultiDomain.make({'b':sp2})
op = ift.NullOperator(mdom1, mdom2)
ift.extra.consistency_check(op, dtype, dtype)
op = ift.NullOperator(sp1, mdom2)
ift.extra.consistency_check(op, dtype, dtype)
op = ift.NullOperator(mdom1, sp2)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces
+ _pow_spaces,
[np.float64, np.complex128]))
def testMultiAdaptor(self, sp, dtype):
mdom = ift.MultiDomain.make({'a':sp})
op = ift.MultiAdaptor(mdom)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_p_RG_spaces,
[np.float64, np.complex128]))
def testHarmonicSmoothingOperator(self, sp, dtype):
op = ift.HarmonicSmoothingOperator(sp, 0.1)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces
+ _pow_spaces,
[np.float64, np.complex128]))
def testDOFDistributor(self, sp, dtype):
#TODO: Test for DomainTuple
if sp.size < 4:
return
dofdex = np.arange(sp.size).reshape(sp.shape)%3
dofdex = ift.Field.from_global_data(sp, dofdex)
op = ift.DOFDistributor(dofdex)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces, [np.float64, np.complex128])) @expand(product(_h_spaces, [np.float64, np.complex128]))
def testPPO(self, sp, dtype): def testPPO(self, sp, dtype):
op = ift.PowerDistributor(target=sp) op = ift.PowerDistributor(target=sp)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment