Skip to content
Snippets Groups Projects
Commit 9a1c07ed authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more

parent c9e87ba7
No related branches found
No related tags found
No related merge requests found
...@@ -607,6 +607,13 @@ class Field(object): ...@@ -607,6 +607,13 @@ class Field(object):
return False return False
return (self._val == other._val).all() return (self._val == other._val).all()
def extract(self, dom):
if dom is not self._domain:
raise ValueError("domain mismatch")
return self
def unite(self, other):
return self + other
for op in ["__add__", "__radd__", for op in ["__add__", "__radd__",
"__sub__", "__rsub__", "__sub__", "__rsub__",
......
...@@ -218,6 +218,9 @@ class MultiField(object): ...@@ -218,6 +218,9 @@ class MultiField(object):
else: else:
return MultiField.from_dict({key: self[key] for key in subset}) return MultiField.from_dict({key: self[key] for key in subset})
def unite(self, other):
return self.combine((self, other))
@staticmethod @staticmethod
def combine(fields): def combine(fields):
res = {} res = {}
......
...@@ -38,7 +38,7 @@ class Linearization(object): ...@@ -38,7 +38,7 @@ class Linearization(object):
if isinstance(other, Linearization): if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator from .operators.relaxed_sum_operator import RelaxedSumOperator
return Linearization( return Linearization(
MultiField.combine((self._val, other._val)), self._val.unite(other._val),
RelaxedSumOperator((self._jac, other._jac))) RelaxedSumOperator((self._jac, other._jac)))
if isinstance(other, (int, float, complex, Field, MultiField)): if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac) return Linearization(self._val+other, self._jac)
......
...@@ -60,14 +60,6 @@ class RelaxedSumOperator(LinearOperator): ...@@ -60,14 +60,6 @@ class RelaxedSumOperator(LinearOperator):
self._check_mode(mode) self._check_mode(mode)
res = None res = None
for op in self._ops: for op in self._ops:
if isinstance(x.domain, MultiDomain): x = op.apply(x.extract(op._dom(mode)), mode)
x = x.extract(op._dom(mode)) res = x if res is None else res.unite(x)
x = op.apply(x, mode)
if res is None:
res = tmp
else:
if isinstance(x.domain, MultiDomain):
res = MultiField.combine([res, tmp])
else:
res = res + tmp
return res return res
...@@ -250,6 +250,7 @@ def domain_union(domains): ...@@ -250,6 +250,7 @@ def domain_union(domains):
return domains[0] return domains[0]
return MultiDomain.union(domains) return MultiDomain.union(domains)
# Arithmetic functions working on Fields # Arithmetic functions working on Fields
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment