Commit 59d7def9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

intermediate state

parent 8bdf38d0
......@@ -606,11 +606,6 @@ class Field(object):
return False
return (self._val == other._val).all()
def isSubsetOf(self, other):
"""Identical to `Field.isEquivalentTo()`. This method is provided for
easier interoperability with `MultiField`."""
return self.isEquivalentTo(other)
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
......@@ -170,8 +170,6 @@ class MetricInversionEnabler(Energy):
self._preconditioner = preconditioner
def at(self, position):
if self._position.isSubsetOf(position):
return self
return MetricInversionEnabler(
self._energy.at(position), self._controller, self._preconditioner)
......
......@@ -34,8 +34,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(operators, tuple):
raise TypeError("tuple expected")
self._domain = domain
self._ops = tuple(operators[key] for key in self.domain.keys())
self._ops = operators
self._cap = self._all_ops
for op in self._ops:
if op is not None:
......@@ -64,15 +68,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
def _combine_chain(self, op):
if self._domain is not op._domain:
raise ValueError("domain mismatch")
res = {key: v1*v2
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
res = tuple(v1*v2 for v1, v2 in zip(self._ops, op._ops))
return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
if self._domain is not op._domain:
raise ValueError("domain mismatch")
res = {}
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops):
res[key] = SumOperator.make([v1, v2], [selfneg, opneg])
res = tuple(SumOperator.make([v1, v2], [selfneg, opneg])
for v1, v2 in zip(self._ops, op._ops))
return BlockDiagonalOperator(self._domain, res)
......@@ -191,26 +191,45 @@ class MultiField(object):
return False
return True
def isSubsetOf(self, other):
"""Determines (as quickly as possible) whether `self`'s content is
a subset of `other`'s content."""
if self is other:
return True
if not isinstance(other, MultiField):
return False
if len(set(self._domain.keys()) - set(other._domain.keys())) > 0:
return False
for key in self._domain.keys():
if other._domain[key] is not self._domain[key]:
return False
if not other[key].isSubsetOf(self[key]):
return False
return True
for op in ["__add__", "__radd__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = []
for v1, v2 in zip(self._val, other._val):
if v1 is not None:
val.append(v1 if v2 is None else (v1+v2))
else:
val.append(None if v2 is None else v2)
val = tuple(val)
else:
val = tuple(other if v1 is None else (v1+other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
for op in ["__mul__", "__rmul__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(None if v1 is None or v2 is None else v1*v2
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(None if v1 is None else (v1*other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__sub__", "__rsub__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
......@@ -218,27 +237,18 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
res = []
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
for v1, v2 in zip(self._val, other._val):
if v1 is not None:
if v2 is None:
res.append(getattr(v1, op)(v1*0))
else:
res.append(getattr(v1, op)(v2))
else:
if v2 is None:
res.append(None)
else:
res.append(getattr(v2*0, op)(v2))
return MultiField(self._domain, tuple(res))
val = tuple(getattr(v1, op)(v2)
for v1, v2 in zip (self._val, other._val))
else:
return self._transform(lambda x: getattr(x, op)(other))
val = tuple(getattr(v1, op)(other) for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"__itruediv__", "__ifloordiv__", "__ipow__"]:
def func(op):
......
......@@ -236,7 +236,7 @@ def makeOp(input):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator(
input.domain, {key: makeOp(val) for key, val in input.items()})
input.domain, tuple(makeOp(val) for val in input.values()))
raise NotImplementedError
# Arithmetic functions working on Fields
......
......@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
def test_blockdiagonal(self):
op = ift.BlockDiagonalOperator(
dom, {"d1": ift.ScalingOperator(20., dom["d1"])})
dom, (ift.ScalingOperator(20., dom["d1"]),))
op2 = op*op
ift.extra.consistency_check(op2)
assert_equal(type(op2), ift.BlockDiagonalOperator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment