Commit aea0988c authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge JaxLinearOperators

parent d8af23d6
Pipeline #107326 failed with stages
in 9 minutes and 41 seconds
......@@ -65,7 +65,7 @@ class JaxOperator(Operator):
if is_linearization(x):
res, bwd = self._vjp(x.val.val)
fwd = lambda y: self._fwd(x.val.val, y)
jac = _JaxLinearOperator(self._domain, self._target, fwd, bwd)
jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=bwd)
return x.new(makeField(self._target, _jax2np(res)), jac)
res = _jax2np(self._func(x.val))
if isinstance(res, dict):
......@@ -100,7 +100,7 @@ class JaxOperator(Operator):
return None, JaxOperator(dom, self._target, func2)
def JaxLinearOperator(domain, target, func, domain_dtype):
class JaxLinearOperator(LinearOperator):
"""Wrap a jax function as nifty linear operator.
Parameters
......@@ -117,24 +117,50 @@ def JaxLinearOperator(domain, target, func, domain_dtype):
`MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target.
func_T : callable
The jax function that implements the transposed action of the operator.
If None, jax computes the adjoint. Note that this is *not* the adjoint
action. Default: None.
domain_dtype:
Dtype of the domain. If `domain` is a `MultiDomain`, `domain_dtype` is
supposed to be a dictionary.
Needs to be set if `func_transposed` is None. Otherwise it does not have
an effect. Dtype of the domain. If `domain` is a `MultiDomain`,
`domain_dtype` is supposed to be a dictionary. Default: None.
Note
----
It is the user's responsibility that func is actually a linear function. The
user can double check this with the help of `nifty8.extra.check_linear_operator`.
user can double check this with the help of
`nifty8.extra.check_linear_operator`.
"""
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple):
inp = np.empty(domain.shape, domain_dtype)
else:
inp = {kk: np.empty(domain[kk].shape, domain_dtype[kk]) for kk in domain.keys()}
func_transposed = jax.jit(jax.linear_transpose(func, inp))
return _JaxLinearOperator(domain, target, func, func_transposed)
def __init__(self, domain, target, func, domain_dtype=None, func_T=None):
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if domain_dtype is not None and func_T is None:
if isinstance(domain, DomainTuple):
inp = np.empty(domain.shape, domain_dtype)
else:
inp = {kk: np.empty(domain[kk].shape, domain_dtype[kk]) for kk in domain.keys()}
func_T = jax.jit(jax.linear_transpose(func, inp))
elif domain_dtype is None and func_T is not None:
pass
else:
raise ValueError("Either domain_dtype or func_T have to be not None.")
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._func_T = func_T
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
from ..sugar import makeField
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
return makeField(self._domain, _jax2np(fx))
fx = self._func_T(x.conjugate().val)[0]
return makeField(self._target, _jax2np(fx)).conjugate()
class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
......@@ -198,22 +224,3 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
else:
dt = self._dt
return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt)
class _JaxLinearOperator(LinearOperator):
def __init__(self, domain, target, func, func_transposed):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._func_transposed = func_transposed
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
from ..sugar import makeField
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
return makeField(self._tgt(mode), _jax2np(fx))
fx = self._func_transposed(x.conjugate().val)[0]
return makeField(self._tgt(mode), _jax2np(fx)).conjugate()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment