Commit 559b8416 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add JaxLinearOperator

parent 70a698c2
Pipeline #107174 passed with stages
in 21 minutes and 4 seconds
...@@ -23,7 +23,7 @@ from .operator import Operator ...@@ -23,7 +23,7 @@ from .operator import Operator
try: try:
import jax import jax
jax.config.update("jax_enable_x64", True) jax.config.update("jax_enable_x64", True)
__all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator"] __all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator", "JaxLinearOperator"]
except ImportError: except ImportError:
__all__ = [] __all__ = []
...@@ -65,7 +65,7 @@ class JaxOperator(Operator): ...@@ -65,7 +65,7 @@ class JaxOperator(Operator):
if is_linearization(x): if is_linearization(x):
res, bwd = self._vjp(x.val.val) res, bwd = self._vjp(x.val.val)
fwd = lambda y: self._fwd(x.val.val, y) fwd = lambda y: self._fwd(x.val.val, y)
jac = _JaxJacobian(self._domain, self._target, fwd, bwd) jac = _JaxLinearOperator(self._domain, self._target, fwd, bwd)
return x.new(makeField(self._target, _jax2np(res)), jac) return x.new(makeField(self._target, _jax2np(res)), jac)
res = _jax2np(self._func(x.val)) res = _jax2np(self._func(x.val))
if isinstance(res, dict): if isinstance(res, dict):
...@@ -100,6 +100,43 @@ class JaxOperator(Operator): ...@@ -100,6 +100,43 @@ class JaxOperator(Operator):
return None, JaxOperator(dom, self._target, func2) return None, JaxOperator(dom, self._target, func2)
def JaxLinearOperator(domain, target, func, domain_dtype):
"""Wrap a jax function as nifty linear operator.
Parameters
----------
domain : DomainTuple or MultiDomain
Domain of the operator.
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the
target.
domain_dtype:
Dtype of the domain. If `domain` is a `MultiDomain`, `domain_dtype` is
supposed to be a dictionary.
Note
----
It is the user's responsibility that func is actually a linear function. The
user can double check this with the help of `nifty8.extra.check_linear_operator`.
"""
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple):
inp = np.ones(domain.shape, domain_dtype)
else:
inp = {kk: np.ones(domain[kk].shape, domain_dtype[kk]) for kk in domain.keys()}
func_transposed = jax.jit(jax.vjp(func, inp)[1])
return _JaxLinearOperator(domain, target, func, func_transposed)
class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator): class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
"""Wrap a jax function as nifty likelihood energy operator. """Wrap a jax function as nifty likelihood energy operator.
...@@ -163,7 +200,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator): ...@@ -163,7 +200,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt) return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt)
class _JaxJacobian(LinearOperator): class _JaxLinearOperator(LinearOperator):
def __init__(self, domain, target, func, func_transposed): def __init__(self, domain, target, func, func_transposed):
from ..sugar import makeDomain from ..sugar import makeDomain
self._domain = makeDomain(domain) self._domain = makeDomain(domain)
......
...@@ -29,16 +29,26 @@ pmp = pytest.mark.parametrize ...@@ -29,16 +29,26 @@ pmp = pytest.mark.parametrize
@pmp("dom", [ift.RGSpace((10, 8)), (ift.RGSpace(10), ift.RGSpace(8))]) @pmp("dom", [ift.RGSpace((10, 8)), (ift.RGSpace(10), ift.RGSpace(8))])
@pmp("func", [lambda x: x, lambda x: x**2, lambda x: x*x, lambda x: x*x[0, 0], @pmp("func", [(lambda x: x, True), (lambda x: x**2, False), (lambda x: x*x, False),
lambda x: jnp.sin(x), lambda x: x*x.sum()]) (lambda x: x*x[0, 0], False), (lambda x: x+x[0, 0], True),
(lambda x: jnp.sin(x), False), (lambda x: x*x.sum(), False),
(lambda x: x+x.sum(), True)])
def test_jax(dom, func): def test_jax(dom, func):
pytest.importorskip("jax") pytest.importorskip("jax")
loc = ift.from_random(dom) loc = ift.from_random(dom)
res0 = np.array(func(loc.val)) f, linear = func
op = ift.JaxOperator(dom, dom, func) res0 = np.array(f(loc.val))
op = ift.JaxOperator(dom, dom, f)
np.testing.assert_allclose(res0, op(loc).val) np.testing.assert_allclose(res0, op(loc).val)
ift.extra.check_operator(op, ift.from_random(op.domain)) ift.extra.check_operator(op, ift.from_random(op.domain))
op = ift.JaxLinearOperator(dom, dom, f, np.float64)
if linear:
ift.extra.check_linear_operator(op)
else:
with pytest.raises(AssertionError):
ift.extra.check_linear_operator(op)
def test_mf_jax(): def test_mf_jax():
pytest.importorskip("jax") pytest.importorskip("jax")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment