Commit 59bf1166 authored by Martin Reinecke's avatar Martin Reinecke

fixes

parent 570f8d8a
Pipeline #29515 passed with stages
in 3 minutes and 59 seconds
......@@ -15,7 +15,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
"""
super(BlockDiagonalOperator, self).__init__()
self._operators = operators
self._domain = MultiDomain(
self._domain = MultiDomain.make(
{key: op.domain for key, op in self._operators.items()})
self._cap = self._all_ops
for op in self._operators.values():
......@@ -43,12 +43,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
res = {}
for key in self._operators.keys():
res[key] = self._operators[key]*op._operators[key]
return res
return BlockDiagonalOperator(res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
res = {}
for key in self._operators.keys():
res[key] = SumOperator.make([self._operators[key],
op._operators[key]],
[selfneg, opneg])
return res
return BlockDiagonalOperator(res)
......@@ -121,6 +121,15 @@ class MultiField(object):
return MultiField({key: Field.full(dom, val)
for key, dom in domain.items()})
def to_global_data(self):
return {key: val.to_global_data() for key, val in self._val.items()}
@staticmethod
def from_global_data(domain, arr, sum_up=False):
return MultiField({key: Field.from_global_data(domain[key],
val, sum_up)
for key, val in arr.items()})
def norm(self):
""" Computes the L2-norm of the field values.
......
......@@ -32,7 +32,6 @@ class Test_Functionality(unittest.TestCase):
assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))
def test_lock(self):
s1 = ift.RGSpace((10,))
f1 = ift.full(dom, 27)
assert_equal(f1.locked, False)
f1.lock()
......@@ -42,13 +41,26 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f1.locked_copy() is f1, True)
def test_fill(self):
s1 = ift.RGSpace((10,))
f1 = ift.full(s1, 27)
assert_equal((f1.fill(10) == 10).all(), True)
f1 = ift.full(dom, 27)
f1.fill(10)
for val in f1.values():
assert_equal((val == 10).all(), True)
def test_dataconv(self):
s1 = ift.RGSpace((10,))
ld = np.arange(ift.dobj.local_shape(s1.shape)[0])
gd = np.arange(s1.shape[0])
assert_equal(ld, ift.from_local_data(s1, ld).local_data)
assert_equal(gd, ift.from_global_data(s1, gd).to_global_data())
f1 = ift.full(dom, 27)
f2 = ift.from_global_data(dom, f1.to_global_data())
for key, val in f1.items():
assert_equal(val.local_data, f2[key].local_data)
def test_blockdiagonal(self):
op = ift.BlockDiagonalOperator({"d1": ift.ScalingOperator(20., dom["d1"])})
op2 = op*op
assert_equal(type(op2), ift.BlockDiagonalOperator)
f1 = op2(ift.full(dom, 1))
for val in f1.values():
assert_equal((val == 400).all(), True)
op2 = op+op
assert_equal(type(op2), ift.BlockDiagonalOperator)
f1 = op2(ift.full(dom, 1))
for val in f1.values():
assert_equal((val == 40).all(), True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment