Commit 46c7f781 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into fix_mpi_kl

parents ff25694b 3490e9ce
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np import numpy as np
from numpy.testing import assert_
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
from .field import Field from .field import Field
from .linearization import Linearization from .linearization import Linearization
from .multi_domain import MultiDomain from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.linear_operator import LinearOperator from .operators.linear_operator import LinearOperator
from .sugar import from_random from .sugar import from_random
...@@ -81,6 +83,38 @@ def _check_linearity(op, domain_dtype, atol, rtol): ...@@ -81,6 +83,38 @@ def _check_linearity(op, domain_dtype, atol, rtol):
_assert_allclose(val1, val2, atol=atol, rtol=rtol) _assert_allclose(val1, val2, atol=atol, rtol=rtol)
def _actual_domain_check(op, domain_dtype=None, inp=None):
needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap:
return
if domain_dtype is not None:
inp = from_random("normal", op.domain, dtype=domain_dtype)
elif inp is None:
raise ValueError('Need to specify either dtype or inp')
assert_(inp.domain is op.domain)
assert_(op(inp).domain is op.target)
def _actual_domain_check_nonlinear(op, loc, target_dtype=np.float64):
assert isinstance(loc, (Field, MultiField))
assert_(loc.domain is op.domain)
lin = Linearization.make_var(loc, False)
reslin = op(lin)
assert_(lin.domain is op.domain)
assert_(lin.target is op.domain)
assert_(lin.val.domain is lin.domain)
assert_(reslin.domain is op.domain)
assert_(reslin.target is op.target)
assert_(reslin.val.domain is reslin.target)
assert_(reslin.target is op.target)
assert_(reslin.jac.domain is reslin.domain)
assert_(reslin.jac.target is reslin.target)
_actual_domain_check(reslin.jac, inp=loc)
_actual_domain_check(reslin.jac.adjoint, domain_dtype=target_dtype)
def _domain_check(op): def _domain_check(op):
for dd in [op.domain, op.target]: for dd in [op.domain, op.target]:
if not isinstance(dd, (DomainTuple, MultiDomain)): if not isinstance(dd, (DomainTuple, MultiDomain)):
...@@ -123,6 +157,10 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -123,6 +157,10 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
if not isinstance(op, LinearOperator): if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.') raise TypeError('This test tests only linear operators.')
_domain_check(op) _domain_check(op)
_actual_domain_check(op, domain_dtype)
_actual_domain_check(op.adjoint, target_dtype)
_actual_domain_check(op.inverse, target_dtype)
_actual_domain_check(op.adjoint.inverse, domain_dtype)
_check_linearity(op, domain_dtype, atol, rtol) _check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol) _check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol) _check_linearity(op.inverse, target_dtype, atol, rtol)
...@@ -180,6 +218,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100): ...@@ -180,6 +218,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
Tolerance for the check. Tolerance for the check.
""" """
_domain_check(op) _domain_check(op)
_actual_domain_check_nonlinear(op, loc)
for _ in range(ntries): for _ in range(ntries):
lin = op(Linearization.make_var(loc)) lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin) loc2, lin2 = _get_acceptable_location(op, loc, lin)
......
...@@ -187,9 +187,7 @@ class _SlowFieldAdapter(LinearOperator): ...@@ -187,9 +187,7 @@ class _SlowFieldAdapter(LinearOperator):
self._check_input(x, mode) self._check_input(x, mode)
if isinstance(x, MultiField): if isinstance(x, MultiField):
return x[self._name] return x[self._name]
else: return MultiField.from_dict({self._name: x}, domain=self._tgt(mode))
return MultiField.from_dict({self._name: x},
domain=self._tgt(mode))
def __repr__(self): def __repr__(self):
return '_SlowFieldAdapter' return '_SlowFieldAdapter'
...@@ -338,12 +336,17 @@ class PartialExtractor(LinearOperator): ...@@ -338,12 +336,17 @@ class PartialExtractor(LinearOperator):
if self._domain[key] is not self._target[key]: if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._compldomain = MultiDomain.make({kk: self._domain[kk]
for kk in self._domain.keys()
if kk not in self._target.keys()})
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.TIMES:
return x.extract(self._target) return x.extract(self._target)
return MultiField.from_dict({key: x[key] for key in x.domain.keys()}) res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res1 = MultiField.full(self._compldomain, 0.)
return res0.unite(res1)
class MatrixProductOperator(EndomorphicOperator): class MatrixProductOperator(EndomorphicOperator):
...@@ -359,20 +362,19 @@ class MatrixProductOperator(EndomorphicOperator): ...@@ -359,20 +362,19 @@ class MatrixProductOperator(EndomorphicOperator):
`dot()` and `transpose()` in the style of numpy arrays. `dot()` and `transpose()` in the style of numpy arrays.
""" """
def __init__(self, domain, matrix): def __init__(self, domain, matrix):
self._domain = domain self._domain = DomainTuple.make(domain)
shp = self._domain.shape
if len(shp) > 1:
raise TypeError('Only 1D-domain supported yet.')
if matrix.shape != (*shp, *shp):
raise ValueError
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._mat = matrix self._mat = matrix
self._mat_tr = matrix.transpose() self._mat_tr = matrix.transpose().conjugate()
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
res = x.to_global_data() res = x.to_global_data()
if mode == self.TIMES: f = self._mat.dot if mode == self.TIMES else self._mat_tr.dot
res = self._mat.dot(res) res = f(res)
if mode == self.ADJOINT_TIMES:
res = self._mat_tr.dot(res)
return Field.from_global_data(self._domain, res) return Field.from_global_data(self._domain, res)
def __repr__(self):
return "MatrixProductOperator"
...@@ -56,7 +56,7 @@ def test_inverse_gamma(field): ...@@ -56,7 +56,7 @@ def test_inverse_gamma(field):
d = np.random.normal(10, size=space.shape)**2 d = np.random.normal(10, size=space.shape)**2
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.InverseGammaLikelihood(d) energy = ift.InverseGammaLikelihood(d)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-7) ift.extra.check_jacobian_consistency(energy, field, tol=1e-5)
def testPoissonian(field): def testPoissonian(field):
...@@ -86,4 +86,4 @@ def test_bernoulli(field): ...@@ -86,4 +86,4 @@ def test_bernoulli(field):
d = np.random.binomial(1, 0.1, size=space.shape) d = np.random.binomial(1, 0.1, size=space.shape)
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.BernoulliEnergy(d) energy = ift.BernoulliEnergy(d)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6) ift.extra.check_jacobian_consistency(energy, field, tol=1e-5)
...@@ -295,3 +295,34 @@ def testValueInserter(sp, seed): ...@@ -295,3 +295,34 @@ def testValueInserter(sp, seed):
ind.append(np.random.randint(0, ss-1)) ind.append(np.random.randint(0, ss-1))
op = ift.ValueInserter(sp, ind) op = ift.ValueInserter(sp, ind)
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
@pmp('sp', [ift.RGSpace(10)])
@pmp('seed', [12, 3])
def testMatrixProductOperator(sp, seed):
np.random.seed(seed)
mat = np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
ift.extra.consistency_check(op)
mat = mat + 1j*np.random.randn(*sp.shape, *sp.shape)
op = ift.MatrixProductOperator(sp, mat)
ift.extra.consistency_check(op)
@pmp('seed', [12, 3])
def testPartialExtractor(seed):
np.random.seed(seed)
tgt = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
dom = tgt.copy()
dom['c'] = ift.RGSpace(3)
dom = ift.MultiDomain.make(dom)
tgt = ift.MultiDomain.make(tgt)
op = ift.PartialExtractor(dom, tgt)
ift.extra.consistency_check(op)
@pmp('seed', [12, 3])
def testSlowFieldAdapter(seed):
dom = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
op = ift.operators.simple_linear_operators._SlowFieldAdapter(dom, 'a')
ift.extra.consistency_check(op)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment