Commit 44ee2914 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'jax_linear_operator' into 'NIFTy_8'

Jax linear operator

See merge request !673
parents 80fd5456 1b646049
Pipeline #107709 passed with stages
in 35 minutes and 51 seconds
......@@ -14,17 +14,18 @@
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
from types import SimpleNamespace
import numpy as np
from .operator import Operator
from .energy_operators import EnergyOperator, LikelihoodEnergyOperator
from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
from .energy_operators import LikelihoodEnergyOperator
from .linear_operator import LinearOperator
from .operator import Operator
try:
import jax
jax.config.update("jax_enable_x64", True)
__all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator"]
__all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator", "JaxLinearOperator"]
except ImportError:
__all__ = []
......@@ -48,7 +49,7 @@ class JaxOperator(Operator):
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the
`MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target.
"""
def __init__(self, domain, target, func):
......@@ -60,13 +61,13 @@ class JaxOperator(Operator):
self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1])
def apply(self, x):
from ..sugar import is_linearization, makeField
from ..multi_domain import MultiDomain
from ..sugar import is_linearization, makeField
self._check_input(x)
if is_linearization(x):
res, bwd = self._vjp(x.val.val)
fwd = lambda y: self._fwd(x.val.val, y)
jac = _JaxJacobian(self._domain, self._target, fwd, bwd)
jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=bwd)
return x.new(makeField(self._target, _jax2np(res)), jac)
res = _jax2np(self._func(x.val))
if isinstance(res, dict):
......@@ -101,6 +102,70 @@ class JaxOperator(Operator):
return None, JaxOperator(dom, self._target, func2)
class JaxLinearOperator(LinearOperator):
"""Wrap a jax function as nifty linear operator.
Parameters
----------
domain : DomainTuple or MultiDomain
Domain of the operator.
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target.
func_T : callable
The jax function that implements the transposed action of the operator.
If None, jax computes the adjoint. Note that this is *not* the adjoint
action. Default: None.
domain_dtype:
Needs to be set if `func_transposed` is None. Otherwise it does not have
an effect. Dtype of the domain. If `domain` is a `MultiDomain`,
`domain_dtype` is supposed to be a dictionary. Default: None.
Note
----
It is the user's responsibility that func is actually a linear function. The
user can double check this with the help of
`nifty8.extra.check_linear_operator`.
"""
def __init__(self, domain, target, func, domain_dtype=None, func_T=None):
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if domain_dtype is not None and func_T is None:
if isinstance(domain, DomainTuple):
inp = SimpleNamespace(shape=domain.shape, dtype=domain_dtype)
else:
inp = {kk: SimpleNameSpace(shape=domain[kk].shape, dtype=domain_dtype[kk])
for kk in domain.keys()}
func_T = jax.jit(jax.linear_transpose(func, inp))
elif domain_dtype is None and func_T is not None:
pass
else:
raise ValueError("Either domain_dtype or func_T have to be not None.")
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._func_T = func_T
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
from ..sugar import makeField
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
return makeField(self._target, _jax2np(fx))
fx = self._func_T(x.conjugate().val)[0]
return makeField(self._domain, _jax2np(fx)).conjugate()
class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
"""Wrap a jax function as nifty likelihood energy operator.
......@@ -112,7 +177,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the
`MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target. It needs to map to a scalar.
transformation : Operator, optional
......@@ -137,9 +202,9 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
return self._dt, self._trafo
def apply(self, x):
from ..linearization import Linearization
from ..sugar import is_linearization, makeField
from .simple_linear_operators import VdotOperator
from ..linearization import Linearization
self._check_input(x)
lin = is_linearization(x)
val = x.val.val if lin else x.val
......@@ -162,22 +227,3 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
else:
dt = self._dt
return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt)
class _JaxJacobian(LinearOperator):
def __init__(self, domain, target, func, func_transposed):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._func_transposed = func_transposed
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
from ..sugar import makeField
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
return makeField(self._tgt(mode), _jax2np(fx))
fx = self._func_transposed(x.conjugate().val)[0]
return makeField(self._tgt(mode), _jax2np(fx)).conjugate()
......@@ -16,7 +16,6 @@
import nifty8 as ift
import numpy as np
import matplotlib.pyplot as plt
import pytest
try:
import jax.numpy as jnp
......@@ -29,16 +28,26 @@ pmp = pytest.mark.parametrize
@pmp("dom", [ift.RGSpace((10, 8)), (ift.RGSpace(10), ift.RGSpace(8))])
@pmp("func", [lambda x: x, lambda x: x**2, lambda x: x*x, lambda x: x*x[0, 0],
lambda x: jnp.sin(x), lambda x: x*x.sum()])
@pmp("func", [(lambda x: x, True), (lambda x: x**2, False), (lambda x: x*x, False),
(lambda x: x*x[0, 0], False), (lambda x: x+x[0, 0], True),
(lambda x: jnp.sin(x), False), (lambda x: x*x.sum(), False),
(lambda x: x+x.sum(), True)])
def test_jax(dom, func):
pytest.importorskip("jax")
loc = ift.from_random(dom)
res0 = np.array(func(loc.val))
op = ift.JaxOperator(dom, dom, func)
f, linear = func
res0 = np.array(f(loc.val))
op = ift.JaxOperator(dom, dom, f)
np.testing.assert_allclose(res0, op(loc).val)
ift.extra.check_operator(op, ift.from_random(op.domain))
op = ift.JaxLinearOperator(dom, dom, f, np.float64)
if linear:
ift.extra.check_linear_operator(op)
else:
with pytest.raises(Exception):
ift.extra.check_linear_operator(op)
def test_mf_jax():
pytest.importorskip("jax")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment