Commit 1aae596d authored by Philipp Arras's avatar Philipp Arras
Browse files

Add wrapper for jax

parent 94171dce
......@@ -12,9 +12,7 @@ RUN apt-get update && apt-get install -y \
# Optional NIFTy dependencies
python3-mpi4py python3-matplotlib \
# more optional NIFTy dependencies
&& pip3 install ducc0 \
&& pip3 install finufft \
&& pip3 install jupyter \
&& pip3 install ducc0 finufft jupyter jax jaxlib \
&& rm -rf /var/lib/apt/lists/*
# Set matplotlib backend
......
......@@ -53,6 +53,7 @@ Optional dependencies:
harmonic transforms, and radio interferometry gridding support
- [mpi4py](https://mpi4py.scipy.org) (for MPI-parallel execution)
- [matplotlib](https://matplotlib.org/) (for field plotting)
- [jax](https://github.com/google/jax) (for implementing operators with jax)
### Sources
......@@ -79,6 +80,8 @@ The DUCC0 package is installed via:
pip3 install ducc0
For installing jax refer to [google/jax:README#Installation](https://github.com/google/jax#installation).
If this library is present, NIFTy will detect it automatically and prefer
`ducc0.fft` over SciPy's FFT. The underlying code is actually the same, but
DUCC's FFT is compiled with optimizations for the host CPU and can provide
......
......@@ -54,6 +54,7 @@ from .operators.energy_operators import (
from .operators.convolution_operators import FuncConvolutionOperator
from .operators.normal_operators import NormalTransform, LognormalTransform
from .operators.multifield2vector import Multifield2Vector
from .operators.jax_operator import *
from .probing import probe_with_posterior_samples, probe_diagonal, \
StatCalculator, approximation2endo
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import numpy as np
from .operator import Operator
from .linear_operator import LinearOperator
try:
import jax
jax.config.update("jax_enable_x64", True)
__all__ = ["JaxOperator"]
except ImportError:
__all__ = []
def _jax2np(obj):
if isinstance(obj, dict):
return {kk: np.array(vv) for kk, vv in obj.items()}
return np.array(obj)
class JaxOperator(Operator):
"""Wrap a jax function as nifty operator.
Parameters
----------
domain : DomainTuple or MultiDomain
Domain of the operator.
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the
target.
"""
def __init__(self, domain, target, func):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = jax.jit(func)
def apply(self, x):
from ..sugar import is_linearization, makeField
self._check_input(x)
if is_linearization(x):
res, bwd = jax.vjp(self._func, x.val.val)
fwd = lambda y: jax.jvp(self._func, (x.val.val,), (y,))[1]
jac = _JaxJacobian(self._domain, self._target, fwd, bwd)
return x.new(makeField(self._target, _jax2np(res)), jac)
return makeField(self._target, _jax2np(self._func(x.val)))
def _simplify_for_constant_input_nontrivial(self, c_inp):
func2 = lambda x: self._func({**x, **c_inp.val})
dom = {kk: vv for kk, vv in self._domain.items()
if kk not in c_inp.keys()}
return None, JaxOperator(dom, self._target, func2)
class _JaxJacobian(LinearOperator):
def __init__(self, domain, target, func, adjfunc):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._adjfunc = adjfunc
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
from ..sugar import makeField
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
else:
fx = self._adjfunc(x.val)[0]
return makeField(self._tgt(mode), _jax2np(fx))
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import nifty7 as ift
import numpy as np
import matplotlib.pyplot as plt
import pytest
try:
import jax.numpy as jnp
_skip = False
except ImportError:
import numpy as np
_skip = True
from ..common import setup_function, teardown_function
pmp = pytest.mark.parametrize
@pmp("dom", [ift.RGSpace((10, 8)), (ift.RGSpace(10), ift.RGSpace(8))])
@pmp("func", [lambda x: x, lambda x: x**2, lambda x: x*x, lambda x: x*x[0, 0],
lambda x: jnp.sin(x), lambda x: x*x.sum()])
def test_jax(dom, func):
if _skip:
pytest.skip()
loc = ift.from_random(dom)
res0 = np.array(func(loc.val))
op = ift.JaxOperator(dom, dom, func)
np.testing.assert_allclose(res0, op(loc).val)
ift.extra.check_operator(op, ift.from_random(op.domain))
def test_mf_jax():
if _skip:
pytest.skip()
dom = ift.makeDomain({"a": ift.RGSpace(10), "b": ift.UnstructuredDomain(2)})
func = lambda x: x["a"]*x["b"][0]
op = ift.JaxOperator(dom, dom["a"], func)
loc = ift.from_random(op.domain)
np.testing.assert_allclose(np.array(func(loc.val)), op(loc).val)
ift.extra.check_operator(op, loc)
func = lambda x: {"a": jnp.full(dom["a"].shape, 2.)*x[0]*x[1], "b": jnp.full(dom["b"].shape, 1.)*jnp.exp(x[0])}
op = ift.JaxOperator(dom["b"], dom, func)
loc = ift.from_random(op.domain)
for kk in dom.keys():
np.testing.assert_allclose(np.array(func(loc.val)[kk]), op(loc)[kk].val)
ift.extra.check_operator(op, loc)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment