From 9cdba8173cf359244f7d98a5e191b364f65525c0 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Tue, 5 Feb 2019 13:11:36 +0100 Subject: [PATCH] next round of fixes; should work now (tm) --- nifty5/operators/simple_linear_operators.py | 20 ++++++++++---------- test/test_operators/test_jacobian.py | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 7eff6519..4cae72cf 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -189,20 +189,20 @@ def ducktape(left, right, name): """ from ..sugar import makeDomain from .operator import Operator + if isinstance(right, Operator): + right = right.target + elif right is not None: + right = makeDomain(right) + if isinstance(left, Operator): + left = left.domain + elif left is not None: + left = makeDomain(left) if left is None: # need to infer left from right - if isinstance(right, Operator): - right = right.target - elif right is not None: - right = makeDomain(right) if isinstance(right, MultiDomain): left = right[name] else: left = MultiDomain.make({name: right}) - else: # need to infer right from left - if isinstance(left, Operator): - left = left.domain - else: - left = makeDomain(left) + elif right is None: # need to infer right from left if isinstance(left, MultiDomain): right = left[name] else: @@ -218,7 +218,7 @@ def ducktape(left, right, name): return _SlowFieldAdapter(left, name).adjoint if rmulti: if len(right) == 1: - return FieldAdapter(right, name) + return FieldAdapter(left, name) else: return _SlowFieldAdapter(right, name) raise ValueError("must not arrive here") diff --git a/test/test_operators/test_jacobian.py b/test/test_operators/test_jacobian.py index 4b825b06..4bef7b25 100644 --- a/test/test_operators/test_jacobian.py +++ b/test/test_operators/test_jacobian.py @@ -61,8 +61,8 @@ def testBinary(type1, type2, space, seed): _make_linearization(type2, dom2, seed) dom = ift.MultiDomain.union((dom1, dom2)) - select_s1 = ift.ducktape(None, dom, "s1") - select_s2 = ift.ducktape(None, dom, "s2") + select_s1 = ift.ducktape(None, dom1, "s1") + select_s2 = ift.ducktape(None, dom2, "s2") model = select_s1*select_s2 pos = ift.from_random("normal", dom) ift.extra.check_jacobian_consistency(model, pos, ntries=20) -- GitLab