Commit 44ee2914 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'jax_linear_operator' into 'NIFTy_8'

Jax linear operator

See merge request !673
parents 80fd5456 1b646049
Pipeline #107709 passed with stages
in 35 minutes and 51 seconds
...@@ -14,17 +14,18 @@ ...@@ -14,17 +14,18 @@
# Copyright(C) 2021 Max-Planck-Society # Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras # Author: Philipp Arras
from types import SimpleNamespace
import numpy as np import numpy as np
from .operator import Operator
from .energy_operators import EnergyOperator, LikelihoodEnergyOperator
from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
from .energy_operators import LikelihoodEnergyOperator
from .linear_operator import LinearOperator
from .operator import Operator
try: try:
import jax import jax
jax.config.update("jax_enable_x64", True) jax.config.update("jax_enable_x64", True)
__all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator"] __all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator", "JaxLinearOperator"]
except ImportError: except ImportError:
__all__ = [] __all__ = []
...@@ -48,7 +49,7 @@ class JaxOperator(Operator): ...@@ -48,7 +49,7 @@ class JaxOperator(Operator):
func : callable func : callable
The jax function that is evaluated by the operator. It has to be The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the `MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target. target.
""" """
def __init__(self, domain, target, func): def __init__(self, domain, target, func):
...@@ -60,13 +61,13 @@ class JaxOperator(Operator): ...@@ -60,13 +61,13 @@ class JaxOperator(Operator):
self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1]) self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1])
def apply(self, x): def apply(self, x):
from ..sugar import is_linearization, makeField
from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
from ..sugar import is_linearization, makeField
self._check_input(x) self._check_input(x)
if is_linearization(x): if is_linearization(x):
res, bwd = self._vjp(x.val.val) res, bwd = self._vjp(x.val.val)
fwd = lambda y: self._fwd(x.val.val, y) fwd = lambda y: self._fwd(x.val.val, y)
jac = _JaxJacobian(self._domain, self._target, fwd, bwd) jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=bwd)
return x.new(makeField(self._target, _jax2np(res)), jac) return x.new(makeField(self._target, _jax2np(res)), jac)
res = _jax2np(self._func(x.val)) res = _jax2np(self._func(x.val))
if isinstance(res, dict): if isinstance(res, dict):
...@@ -101,6 +102,70 @@ class JaxOperator(Operator): ...@@ -101,6 +102,70 @@ class JaxOperator(Operator):
return None, JaxOperator(dom, self._target, func2) return None, JaxOperator(dom, self._target, func2)
class JaxLinearOperator(LinearOperator):
"""Wrap a jax function as nifty linear operator.
Parameters
----------
domain : DomainTuple or MultiDomain
Domain of the operator.
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target.
func_T : callable
The jax function that implements the transposed action of the operator.
If None, jax computes the adjoint. Note that this is *not* the adjoint
action. Default: None.
domain_dtype:
Needs to be set if `func_transposed` is None. Otherwise it does not have
an effect. Dtype of the domain. If `domain` is a `MultiDomain`,
`domain_dtype` is supposed to be a dictionary. Default: None.
Note
----
It is the user's responsibility that func is actually a linear function. The
user can double check this with the help of
`nifty8.extra.check_linear_operator`.
"""
def __init__(self, domain, target, func, domain_dtype=None, func_T=None):
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if domain_dtype is not None and func_T is None:
if isinstance(domain, DomainTuple):
inp = SimpleNamespace(shape=domain.shape, dtype=domain_dtype)
else:
inp = {kk: SimpleNameSpace(shape=domain[kk].shape, dtype=domain_dtype[kk])
for kk in domain.keys()}
func_T = jax.jit(jax.linear_transpose(func, inp))
elif domain_dtype is None and func_T is not None:
pass
else:
raise ValueError("Either domain_dtype or func_T have to be not None.")
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._func_T = func_T
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
from ..sugar import makeField
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
return makeField(self._target, _jax2np(fx))
fx = self._func_T(x.conjugate().val)[0]
return makeField(self._domain, _jax2np(fx)).conjugate()
class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator): class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
"""Wrap a jax function as nifty likelihood energy operator. """Wrap a jax function as nifty likelihood energy operator.
...@@ -112,7 +177,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator): ...@@ -112,7 +177,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
func : callable func : callable
The jax function that is evaluated by the operator. It has to be The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the `MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target. It needs to map to a scalar. target. It needs to map to a scalar.
transformation : Operator, optional transformation : Operator, optional
...@@ -137,9 +202,9 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator): ...@@ -137,9 +202,9 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
return self._dt, self._trafo return self._dt, self._trafo
def apply(self, x): def apply(self, x):
from ..linearization import Linearization
from ..sugar import is_linearization, makeField from ..sugar import is_linearization, makeField
from .simple_linear_operators import VdotOperator from .simple_linear_operators import VdotOperator
from ..linearization import Linearization
self._check_input(x) self._check_input(x)
lin = is_linearization(x) lin = is_linearization(x)
val = x.val.val if lin else x.val val = x.val.val if lin else x.val
...@@ -162,22 +227,3 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator): ...@@ -162,22 +227,3 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
else: else:
dt = self._dt dt = self._dt
return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt) return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt)
class _JaxJacobian(LinearOperator):
def __init__(self, domain, target, func, func_transposed):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._func_transposed = func_transposed
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
from ..sugar import makeField
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
return makeField(self._tgt(mode), _jax2np(fx))
fx = self._func_transposed(x.conjugate().val)[0]
return makeField(self._tgt(mode), _jax2np(fx)).conjugate()
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import nifty8 as ift import nifty8 as ift
import numpy as np import numpy as np
import matplotlib.pyplot as plt
import pytest import pytest
try: try:
import jax.numpy as jnp import jax.numpy as jnp
...@@ -29,16 +28,26 @@ pmp = pytest.mark.parametrize ...@@ -29,16 +28,26 @@ pmp = pytest.mark.parametrize
@pmp("dom", [ift.RGSpace((10, 8)), (ift.RGSpace(10), ift.RGSpace(8))]) @pmp("dom", [ift.RGSpace((10, 8)), (ift.RGSpace(10), ift.RGSpace(8))])
@pmp("func", [lambda x: x, lambda x: x**2, lambda x: x*x, lambda x: x*x[0, 0], @pmp("func", [(lambda x: x, True), (lambda x: x**2, False), (lambda x: x*x, False),
lambda x: jnp.sin(x), lambda x: x*x.sum()]) (lambda x: x*x[0, 0], False), (lambda x: x+x[0, 0], True),
(lambda x: jnp.sin(x), False), (lambda x: x*x.sum(), False),
(lambda x: x+x.sum(), True)])
def test_jax(dom, func): def test_jax(dom, func):
pytest.importorskip("jax") pytest.importorskip("jax")
loc = ift.from_random(dom) loc = ift.from_random(dom)
res0 = np.array(func(loc.val)) f, linear = func
op = ift.JaxOperator(dom, dom, func) res0 = np.array(f(loc.val))
op = ift.JaxOperator(dom, dom, f)
np.testing.assert_allclose(res0, op(loc).val) np.testing.assert_allclose(res0, op(loc).val)
ift.extra.check_operator(op, ift.from_random(op.domain)) ift.extra.check_operator(op, ift.from_random(op.domain))
op = ift.JaxLinearOperator(dom, dom, f, np.float64)
if linear:
ift.extra.check_linear_operator(op)
else:
with pytest.raises(Exception):
ift.extra.check_linear_operator(op)
def test_mf_jax(): def test_mf_jax():
pytest.importorskip("jax") pytest.importorskip("jax")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment