Commit 74737cd0 authored by Martin Reinecke's avatar Martin Reinecke

tentative fix

parent 12ed12b9
Pipeline #43316 failed with stages
in 3 minutes and 21 seconds
...@@ -109,7 +109,7 @@ class Linearization(object): ...@@ -109,7 +109,7 @@ class Linearization(object):
def __getitem__(self, name): def __getitem__(self, name):
from .operators.simple_linear_operators import ducktape from .operators.simple_linear_operators import ducktape
return self.new(self._val[name], ducktape(None, self.target, name)) return self.new(self._val[name], self._jac.ducktape_left(name))
def __neg__(self): def __neg__(self):
return self.new(-self._val, -self._jac, return self.new(-self._val, -self._jac,
......
...@@ -118,13 +118,47 @@ class FieldAdapter(LinearOperator): ...@@ -118,13 +118,47 @@ class FieldAdapter(LinearOperator):
return MultiField(self._tgt(mode), (x,)) return MultiField(self._tgt(mode), (x,))
def __repr__(self): def __repr__(self):
key = self.domain.keys()[0] return 'FieldAdapter'
return 'FieldAdapter: {}'.format(key)
class _SlowFieldAdapter(LinearOperator):
"""Operator for conversion between Fields and MultiFields.
The operator is built so that the MultiDomain is always the target.
Its domain is `tgt[name]`
Parameters
----------
dom : dict or MultiDomain:
the operator's dom
name : String
The relevant key of the MultiDomain.
"""
def __init__(self, dom, name):
from ..sugar import makeDomain
tmp = makeDomain(dom)
if not isinstance(tmp, MultiDomain):
raise TypeError("MultiDomain expected")
self._name = str(name)
self._domain = tmp
self._target = tmp[name]
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if isinstance(x, MultiField):
return x[self._name]
else:
return MultiField.from_dict(self._tgt(mode), {self._name: x})
def __repr__(self):
return '_SlowFieldAdapter'
def ducktape(left, right, name): def ducktape(left, right, name):
"""Convenience function creating an operator that converts between a """Convenience function creating an operator that converts between a
DomainTuple and a single-entry MultiDomain. DomainTuple and a MultiDomain.
Parameters Parameters
---------- ----------
...@@ -146,18 +180,12 @@ def ducktape(left, right, name): ...@@ -146,18 +180,12 @@ def ducktape(left, right, name):
- `left` and `right` must not be both `None`, but one of them can (and - `left` and `right` must not be both `None`, but one of them can (and
probably should) be `None`. In this case, the missing information is probably should) be `None`. In this case, the missing information is
inferred. inferred.
- the returned operator's domains are
- a `DomainTuple` and
- a `MultiDomain` with exactly one entry called `name` and the same
`DomainTuple`
Which of these is the domain and which is the target depends on the
input.
Returns Returns
------- -------
FieldAdapter : an adapter operator converting between the two (possibly FieldAdapter or _SlowFieldAdapter
partially inferred) domains. an adapter operator converting between the two (possibly
partially inferred) domains.
""" """
from ..sugar import makeDomain from ..sugar import makeDomain
from .operator import Operator from .operator import Operator
...@@ -170,12 +198,30 @@ def ducktape(left, right, name): ...@@ -170,12 +198,30 @@ def ducktape(left, right, name):
left = right[name] left = right[name]
else: else:
left = MultiDomain.make({name: right}) left = MultiDomain.make({name: right})
else: else: # need to infer right from left
if isinstance(left, Operator): if isinstance(left, Operator):
left = left.domain left = left.domain
else: else:
left = makeDomain(left) left = makeDomain(left)
return FieldAdapter(left, name) if isinstance(left, MultiDomain):
right = left[name]
else:
right = MultiDomain.make({name: left})
lmulti = isinstance(left, MultiDomain)
rmulti = isinstance(right, MultiDomain)
if lmulti+rmulti != 1:
raise ValueError("need exactly one MultiDomain")
if lmulti:
if len(left) == 1:
return FieldAdapter(left, name)
else:
return _SlowFieldAdapter(left, name).adjoint
if rmulti:
if len(right) == 1:
return FieldAdapter(right, name)
else:
return _SlowFieldAdapter(right, name)
raise ValueError("must not arrive here")
class GeometryRemover(LinearOperator): class GeometryRemover(LinearOperator):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment