Commit 9cdba817 authored by Martin Reinecke's avatar Martin Reinecke

next round of fixes; should work now (tm)

parent 74737cd0
Pipeline #43323 passed with stages
in 8 minutes and 3 seconds
...@@ -189,20 +189,20 @@ def ducktape(left, right, name): ...@@ -189,20 +189,20 @@ def ducktape(left, right, name):
""" """
from ..sugar import makeDomain from ..sugar import makeDomain
from .operator import Operator from .operator import Operator
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(left, Operator):
left = left.domain
elif left is not None:
left = makeDomain(left)
if left is None: # need to infer left from right if left is None: # need to infer left from right
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(right, MultiDomain): if isinstance(right, MultiDomain):
left = right[name] left = right[name]
else: else:
left = MultiDomain.make({name: right}) left = MultiDomain.make({name: right})
else: # need to infer right from left elif right is None: # need to infer right from left
if isinstance(left, Operator):
left = left.domain
else:
left = makeDomain(left)
if isinstance(left, MultiDomain): if isinstance(left, MultiDomain):
right = left[name] right = left[name]
else: else:
...@@ -218,7 +218,7 @@ def ducktape(left, right, name): ...@@ -218,7 +218,7 @@ def ducktape(left, right, name):
return _SlowFieldAdapter(left, name).adjoint return _SlowFieldAdapter(left, name).adjoint
if rmulti: if rmulti:
if len(right) == 1: if len(right) == 1:
return FieldAdapter(right, name) return FieldAdapter(left, name)
else: else:
return _SlowFieldAdapter(right, name) return _SlowFieldAdapter(right, name)
raise ValueError("must not arrive here") raise ValueError("must not arrive here")
......
...@@ -61,8 +61,8 @@ def testBinary(type1, type2, space, seed): ...@@ -61,8 +61,8 @@ def testBinary(type1, type2, space, seed):
_make_linearization(type2, dom2, seed) _make_linearization(type2, dom2, seed)
dom = ift.MultiDomain.union((dom1, dom2)) dom = ift.MultiDomain.union((dom1, dom2))
select_s1 = ift.ducktape(None, dom, "s1") select_s1 = ift.ducktape(None, dom1, "s1")
select_s2 = ift.ducktape(None, dom, "s2") select_s2 = ift.ducktape(None, dom2, "s2")
model = select_s1*select_s2 model = select_s1*select_s2
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_jacobian_consistency(model, pos, ntries=20) ift.extra.check_jacobian_consistency(model, pos, ntries=20)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment