Commit 30293512 authored by Lukas Platz's avatar Lukas Platz
Browse files

revert changes + update docstring

parent 2f1ff8a6
Pipeline #107770 passed with stages
in 21 minutes and 10 seconds
......@@ -122,7 +122,8 @@ class JaxLinearOperator(LinearOperator):
func_T : callable
The jax function that implements the transposed action of the operator.
If None, jax computes the adjoint. Note that this is *not* the adjoint
action. Default: None.
action. Needs to return a tuple with the result as the first entry.
Default: None.
Needs to be set if `func_transposed` is None. Otherwise it does not have
......@@ -147,8 +148,7 @@ class JaxLinearOperator(LinearOperator):
for kk in domain.keys()}
func_T = jax.jit(jax.linear_transpose(func, inp))
elif domain_dtype is None and func_T is not None:
tmp = func_T
func_T = lambda x: (tmp(x),)
raise ValueError("Either domain_dtype or func_T have to be not None.")
self._domain = makeDomain(domain)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment