Commit 9a1c07ed authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more

parent c9e87ba7
......@@ -607,6 +607,13 @@ class Field(object):
return False
return (self._val == other._val).all()
def extract(self, dom):
if dom is not self._domain:
raise ValueError("domain mismatch")
return self
def unite(self, other):
return self + other
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
......@@ -218,6 +218,9 @@ class MultiField(object):
else:
return MultiField.from_dict({key: self[key] for key in subset})
def unite(self, other):
return self.combine((self, other))
@staticmethod
def combine(fields):
res = {}
......
......@@ -38,7 +38,7 @@ class Linearization(object):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
return Linearization(
MultiField.combine((self._val, other._val)),
self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac)))
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac)
......
......@@ -60,14 +60,6 @@ class RelaxedSumOperator(LinearOperator):
self._check_mode(mode)
res = None
for op in self._ops:
if isinstance(x.domain, MultiDomain):
x = x.extract(op._dom(mode))
x = op.apply(x, mode)
if res is None:
res = tmp
else:
if isinstance(x.domain, MultiDomain):
res = MultiField.combine([res, tmp])
else:
res = res + tmp
x = op.apply(x.extract(op._dom(mode)), mode)
res = x if res is None else res.unite(x)
return res
......@@ -250,6 +250,7 @@ def domain_union(domains):
return domains[0]
return MultiDomain.union(domains)
# Arithmetic functions working on Fields
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment