Commit 9628c28d authored by Sebastian Hutschenreuter's avatar Sebastian Hutschenreuter
Browse files

added tests

parent 3c5d3287
...@@ -130,16 +130,14 @@ class Linearization(object): ...@@ -130,16 +130,14 @@ class Linearization(object):
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
if isinstance(other, Linearization): if isinstance(other, Linearization):
return self.new( return self.new(
OuterProduct(self._val, other._val.domain)(other._val), OuterProduct(self._val, other.target)(other._val),
OuterProduct(other._val, self._jac.domain)(self._jac)._myadd( OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct( OuterProduct(self._val, other.target)(other._jac), False))
self._val, other._jac.domain)(other._jac), False))
if np.isscalar(other): if np.isscalar(other):
return self.__mul__(other) return self.__mul__(other)
if isinstance(other, (Field, MultiField)): if isinstance(other, (Field, MultiField)):
return self.new( return self.new(OuterProduct(self._val, other.domain)(other),
OuterProduct(self._val, other._val.domain)(other._val), OuterProduct(self._jac(self._val), other.domain))
OuterProduct(other._val, self._jac.domain)(self._jac))
def vdot(self, other): def vdot(self, other):
from .operators.simple_linear_operators import VdotOperator from .operators.simple_linear_operators import VdotOperator
......
...@@ -144,6 +144,26 @@ class Test_Functionality(unittest.TestCase): ...@@ -144,6 +144,26 @@ class Test_Functionality(unittest.TestCase):
res = m1.outer(m2) res = m1.outer(m2)
assert_allclose(res.to_global_data(), np.full((9, 3,), 1.5)) assert_allclose(res.to_global_data(), np.full((9, 3,), 1.5))
def test_sum(self):
x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace((2, 12,), distances=(0.3,))
m1 = ift.Field(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.sum()
res2 = m2.sum(spaces=1)
assert_allclose(res1, 36)
assert_allclose(res2.to_global_data(), np.full(9, 2*12*0.45))
def test_integrate(self):
x1 = ift.RGSpace((9,), distances=2.)
x2 = ift.RGSpace((2, 12,), distances=(0.3,))
m1 = ift.Field(ift.makeDomain(x1), np.arange(9))
m2 = ift.Field.full(ift.makeDomain((x1, x2)), 0.45)
res1 = m1.integrate()
res2 = m2.integrate(spaces=1)
assert_allclose(res1, 36*2)
assert_allclose(res2.to_global_data(), np.full(9, 2*12*0.45*0.3**2))
def test_dataconv(self): def test_dataconv(self):
s1 = ift.RGSpace((10,)) s1 = ift.RGSpace((10,))
ld = np.arange(ift.dobj.local_shape(s1.shape)[0]) ld = np.arange(ift.dobj.local_shape(s1.shape)[0])
......
...@@ -79,6 +79,9 @@ class Model_Tests(unittest.TestCase): ...@@ -79,6 +79,9 @@ class Model_Tests(unittest.TestCase):
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))) ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")))
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos) ift.extra.check_value_gradient_consistency(model, pos)
pos = ift.from_random("normal", dom)
model = ift.OuterProduct(pos['s1'], ift.makeDomain(space)) #(ift.FieldAdapter(dom, "s2"))
ift.extra.check_value_gradient_consistency(model, pos['s2'])
if isinstance(space, ift.RGSpace): if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)( model = ift.FFTOperator(space)(
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")) ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
......
...@@ -257,3 +257,15 @@ class Consistency_Tests(unittest.TestCase): ...@@ -257,3 +257,15 @@ class Consistency_Tests(unittest.TestCase):
def testRegridding(self, domain, shape, space): def testRegridding(self, domain, shape, space):
op = ift.RegriddingOperator(domain, shape, space) op = ift.RegriddingOperator(domain, shape, space)
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
@expand(product([ift.DomainTuple.make((ift.RGSpace((3, 5, 4)),
ift.RGSpace((16,), distances=(7.,))),),
ift.DomainTuple.make(ift.HPSpace(12),)],
[ift.DomainTuple.make((ift.RGSpace((2,)),
ift.GLSpace(10)),),
ift.DomainTuple.make(ift.RGSpace((10, 12), distances=(0.1, 1.)),)]
))
def testOuter(self, fdomain, domain):
f = ift.from_random('normal', fdomain)
op = ift.OuterProduct(f, domain)
ift.extra.consistency_check(op)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment