Commit ef2848d3 authored by Philipp Arras's avatar Philipp Arras
Browse files

Jax: Actually return derivative and not tuple of derivative

parent 6b7a0e59
Pipeline #108115 passed with stages
in 20 minutes and 59 seconds
......@@ -68,7 +68,7 @@ class JaxOperator(Operator):
if is_linearization(x):
res, bwd = self._vjp(x.val.val)
fwd = lambda y: self._fwd(x.val.val, y)
jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=bwd)
jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=lambda x: bwd(x)[0])
return x.new(makeField(self._target, _jax2np(res)), jac)
res = _jax2np(self._func(x.val))
if isinstance(res, dict):
......@@ -145,7 +145,7 @@ class JaxLinearOperator(LinearOperator):
else:
inp = {kk: SimpleNamespace(shape=domain[kk].shape, dtype=domain_dtype[kk])
for kk in domain.keys()}
func_T = jax.jit(jax.linear_transpose(func, inp))
func_T = jax.jit(lambda x: jax.linear_transpose(func, inp)(x)[0])
elif domain_dtype is None and func_T is not None:
pass
else:
......@@ -162,7 +162,7 @@ class JaxLinearOperator(LinearOperator):
if mode == self.TIMES:
fx = self._func(x.val)
return makeField(self._target, _jax2np(fx))
fx = self._func_T(x.conjugate().val)[0]
fx = self._func_T(x.conjugate().val)
return makeField(self._domain, _jax2np(fx)).conjugate()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment