Commit 59d7def9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

intermediate state

parent 8bdf38d0
...@@ -606,11 +606,6 @@ class Field(object): ...@@ -606,11 +606,6 @@ class Field(object):
return False return False
return (self._val == other._val).all() return (self._val == other._val).all()
def isSubsetOf(self, other):
"""Identical to `Field.isEquivalentTo()`. This method is provided for
easier interoperability with `MultiField`."""
return self.isEquivalentTo(other)
for op in ["__add__", "__radd__", for op in ["__add__", "__radd__",
"__sub__", "__rsub__", "__sub__", "__rsub__",
......
...@@ -170,8 +170,6 @@ class MetricInversionEnabler(Energy): ...@@ -170,8 +170,6 @@ class MetricInversionEnabler(Energy):
self._preconditioner = preconditioner self._preconditioner = preconditioner
def at(self, position): def at(self, position):
if self._position.isSubsetOf(position):
return self
return MetricInversionEnabler( return MetricInversionEnabler(
self._energy.at(position), self._controller, self._preconditioner) self._energy.at(position), self._controller, self._preconditioner)
......
...@@ -34,8 +34,12 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -34,8 +34,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
LinearOperators as items LinearOperators as items
""" """
super(BlockDiagonalOperator, self).__init__() super(BlockDiagonalOperator, self).__init__()
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
if not isinstance(operators, tuple):
raise TypeError("tuple expected")
self._domain = domain self._domain = domain
self._ops = tuple(operators[key] for key in self.domain.keys()) self._ops = operators
self._cap = self._all_ops self._cap = self._all_ops
for op in self._ops: for op in self._ops:
if op is not None: if op is not None:
...@@ -64,15 +68,13 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -64,15 +68,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
def _combine_chain(self, op): def _combine_chain(self, op):
if self._domain is not op._domain: if self._domain is not op._domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
res = {key: v1*v2 res = tuple(v1*v2 for v1, v2 in zip(self._ops, op._ops))
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res) return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg): def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator from ..operators.sum_operator import SumOperator
if self._domain is not op._domain: if self._domain is not op._domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
res = {} res = tuple(SumOperator.make([v1, v2], [selfneg, opneg])
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops): for v1, v2 in zip(self._ops, op._ops))
res[key] = SumOperator.make([v1, v2], [selfneg, opneg])
return BlockDiagonalOperator(self._domain, res) return BlockDiagonalOperator(self._domain, res)
...@@ -191,26 +191,45 @@ class MultiField(object): ...@@ -191,26 +191,45 @@ class MultiField(object):
return False return False
return True return True
def isSubsetOf(self, other):
"""Determines (as quickly as possible) whether `self`'s content is for op in ["__add__", "__radd__"]:
a subset of `other`'s content.""" def func(op):
if self is other: def func2(self, other):
return True if isinstance(other, MultiField):
if not isinstance(other, MultiField): if self._domain is not other._domain:
return False raise ValueError("domain mismatch")
if len(set(self._domain.keys()) - set(other._domain.keys())) > 0: val = []
return False for v1, v2 in zip(self._val, other._val):
for key in self._domain.keys(): if v1 is not None:
if other._domain[key] is not self._domain[key]: val.append(v1 if v2 is None else (v1+v2))
return False else:
if not other[key].isSubsetOf(self[key]): val.append(None if v2 is None else v2)
return False val = tuple(val)
return True else:
val = tuple(other if v1 is None else (v1+other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__add__", "__radd__", for op in ["__mul__", "__rmul__"]:
"__sub__", "__rsub__", def func(op):
"__mul__", "__rmul__", def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(None if v1 is None or v2 is None else v1*v2
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(None if v1 is None else (v1*other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__sub__", "__rsub__",
"__div__", "__rdiv__", "__div__", "__rdiv__",
"__truediv__", "__rtruediv__", "__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__", "__floordiv__", "__rfloordiv__",
...@@ -218,27 +237,18 @@ for op in ["__add__", "__radd__", ...@@ -218,27 +237,18 @@ for op in ["__add__", "__radd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op): def func(op):
def func2(self, other): def func2(self, other):
res = []
if isinstance(other, MultiField): if isinstance(other, MultiField):
if self._domain is not other._domain: if self._domain is not other._domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
for v1, v2 in zip(self._val, other._val): val = tuple(getattr(v1, op)(v2)
if v1 is not None: for v1, v2 in zip (self._val, other._val))
if v2 is None:
res.append(getattr(v1, op)(v1*0))
else:
res.append(getattr(v1, op)(v2))
else:
if v2 is None:
res.append(None)
else:
res.append(getattr(v2*0, op)(v2))
return MultiField(self._domain, tuple(res))
else: else:
return self._transform(lambda x: getattr(x, op)(other)) val = tuple(getattr(v1, op)(other) for v1 in self._val)
return MultiField(self._domain, val)
return func2 return func2
setattr(MultiField, op, func(op)) setattr(MultiField, op, func(op))
for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"__itruediv__", "__ifloordiv__", "__ipow__"]: "__itruediv__", "__ifloordiv__", "__ipow__"]:
def func(op): def func(op):
......
...@@ -236,7 +236,7 @@ def makeOp(input): ...@@ -236,7 +236,7 @@ def makeOp(input):
return DiagonalOperator(input) return DiagonalOperator(input)
if isinstance(input, MultiField): if isinstance(input, MultiField):
return BlockDiagonalOperator( return BlockDiagonalOperator(
input.domain, {key: makeOp(val) for key, val in input.items()}) input.domain, tuple(makeOp(val) for val in input.values()))
raise NotImplementedError raise NotImplementedError
# Arithmetic functions working on Fields # Arithmetic functions working on Fields
......
...@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase): ...@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
def test_blockdiagonal(self): def test_blockdiagonal(self):
op = ift.BlockDiagonalOperator( op = ift.BlockDiagonalOperator(
dom, {"d1": ift.ScalingOperator(20., dom["d1"])}) dom, (ift.ScalingOperator(20., dom["d1"]),))
op2 = op*op op2 = op*op
ift.extra.consistency_check(op2) ift.extra.consistency_check(op2)
assert_equal(type(op2), ift.BlockDiagonalOperator) assert_equal(type(op2), ift.BlockDiagonalOperator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment