Commit 46c7f781 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into fix_mpi_kl

parents ff25694b 3490e9ce
......@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from numpy.testing import assert_
from .domain_tuple import DomainTuple
from .field import Field
from .linearization import Linearization
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.linear_operator import LinearOperator
from .sugar import from_random
......@@ -81,6 +83,38 @@ def _check_linearity(op, domain_dtype, atol, rtol):
_assert_allclose(val1, val2, atol=atol, rtol=rtol)
def _actual_domain_check(op, domain_dtype=None, inp=None):
needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap:
if domain_dtype is not None:
inp = from_random("normal", op.domain, dtype=domain_dtype)
elif inp is None:
raise ValueError('Need to specify either dtype or inp')
assert_(inp.domain is op.domain)
assert_(op(inp).domain is
def _actual_domain_check_nonlinear(op, loc, target_dtype=np.float64):
assert isinstance(loc, (Field, MultiField))
assert_(loc.domain is op.domain)
lin = Linearization.make_var(loc, False)
reslin = op(lin)
assert_(lin.domain is op.domain)
assert_( is op.domain)
assert_(lin.val.domain is lin.domain)
assert_(reslin.domain is op.domain)
assert_( is
assert_(reslin.val.domain is
assert_( is
assert_(reslin.jac.domain is reslin.domain)
assert_( is
_actual_domain_check(reslin.jac, inp=loc)
_actual_domain_check(reslin.jac.adjoint, domain_dtype=target_dtype)
def _domain_check(op):
for dd in [op.domain,]:
if not isinstance(dd, (DomainTuple, MultiDomain)):
......@@ -123,6 +157,10 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.')
_actual_domain_check(op, domain_dtype)
_actual_domain_check(op.adjoint, target_dtype)
_actual_domain_check(op.inverse, target_dtype)
_actual_domain_check(op.adjoint.inverse, domain_dtype)
_check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol)
......@@ -180,6 +218,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
Tolerance for the check.
_actual_domain_check_nonlinear(op, loc)
for _ in range(ntries):
lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin)
......@@ -187,9 +187,7 @@ class _SlowFieldAdapter(LinearOperator):
self._check_input(x, mode)
if isinstance(x, MultiField):
return x[self._name]
return MultiField.from_dict({self._name: x},
return MultiField.from_dict({self._name: x}, domain=self._tgt(mode))
def __repr__(self):
return '_SlowFieldAdapter'
......@@ -338,12 +336,17 @@ class PartialExtractor(LinearOperator):
if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES
self._compldomain = MultiDomain.make({kk: self._domain[kk]
for kk in self._domain.keys()
if kk not in self._target.keys()})
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res1 = MultiField.full(self._compldomain, 0.)
return res0.unite(res1)
class MatrixProductOperator(EndomorphicOperator):
......@@ -359,20 +362,19 @@ class MatrixProductOperator(EndomorphicOperator):
`dot()` and `transpose()` in the style of numpy arrays.
def __init__(self, domain, matrix):
self._domain = domain
self._domain = DomainTuple.make(domain)
shp = self._domain.shape
if len(shp) > 1:
raise TypeError('Only 1D-domain supported yet.')
if matrix.shape != (*shp, *shp):
raise ValueError
self._capability = self.TIMES | self.ADJOINT_TIMES
self._mat = matrix
self._mat_tr = matrix.transpose()
self._mat_tr = matrix.transpose().conjugate()
def apply(self, x, mode):
self._check_input(x, mode)
res = x.to_global_data()
if mode == self.TIMES:
res =
if mode == self.ADJOINT_TIMES:
res =
f = if mode == self.TIMES else
res = f(res)
return Field.from_global_data(self._domain, res)
def __repr__(self):
return "MatrixProductOperator"
......@@ -56,7 +56,7 @@ def test_inverse_gamma(field):
d = np.random.normal(10, size=space.shape)**2
d = ift.Field.from_global_data(space, d)
energy = ift.InverseGammaLikelihood(d)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-7)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-5)
def testPoissonian(field):
......@@ -86,4 +86,4 @@ def test_bernoulli(field):
d = np.random.binomial(1, 0.1, size=space.shape)
d = ift.Field.from_global_data(space, d)
energy = ift.BernoulliEnergy(d)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-5)
......@@ -295,3 +295,34 @@ def testValueInserter(sp, seed):
ind.append(np.random.randint(0, ss-1))
op = ift.ValueInserter(sp, ind)
@pmp('sp', [ift.RGSpace(10)])
@pmp('seed', [12, 3])
def testMatrixProductOperator(sp, seed):
mat = np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
mat = mat + 1j*np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
@pmp('seed', [12, 3])
def testPartialExtractor(seed):
tgt = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
dom = tgt.copy()
dom['c'] = ift.RGSpace(3)
dom = ift.MultiDomain.make(dom)
tgt = ift.MultiDomain.make(tgt)
op = ift.PartialExtractor(dom, tgt)
@pmp('seed', [12, 3])
def testSlowFieldAdapter(seed):
dom = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
op = ift.operators.simple_linear_operators._SlowFieldAdapter(dom, 'a')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment