Commit 67e0951c authored by Philipp Arras's avatar Philipp Arras
Browse files

Use linear_transpose for adjoint of JaxLinearOperator

parent 00000714
......@@ -130,10 +130,10 @@ def JaxLinearOperator(domain, target, func, domain_dtype):
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple):
inp = np.ones(domain.shape, domain_dtype)
inp = np.empty(domain.shape, domain_dtype)
else:
inp = {kk: np.ones(domain[kk].shape, domain_dtype[kk]) for kk in domain.keys()}
func_transposed = jax.jit(jax.vjp(func, inp)[1])
inp = {kk: np.empty(domain[kk].shape, domain_dtype[kk]) for kk in domain.keys()}
func_transposed = jax.jit(jax.linear_transpose(func, inp))
return _JaxLinearOperator(domain, target, func, func_transposed)
......
......@@ -46,7 +46,7 @@ def test_jax(dom, func):
if linear:
ift.extra.check_linear_operator(op)
else:
with pytest.raises(AssertionError):
with pytest.raises(Exception):
ift.extra.check_linear_operator(op)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment