Commit 241ce41f authored by Philipp Arras's avatar Philipp Arras
Browse files

Add jax likelihood operator

parent 441fffe6
......@@ -16,13 +16,15 @@
import numpy as np
from .operator import Operator
from .energy_operators import EnergyOperator, LikelihoodEnergyOperator
from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
try:
import jax
jax.config.update("jax_enable_x64", True)
__all__ = ["JaxOperator"]
__all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator"]
except ImportError:
__all__ = []
......@@ -74,6 +76,37 @@ class JaxOperator(Operator):
return None, JaxOperator(dom, self._target, func2)
class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
def __init__(self, domain, func, transformation=None, sampling_dtype=None):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._func = jax.jit(func)
self._grad = jax.jit(jax.grad(func))
self._dt = sampling_dtype
self._trafo = transformation
def get_transformation(self):
if self._trafo is None:
s = self.__name__ + " was instantiated without `transformation`"
raise RuntimeError(s)
return self._dt, self._trafo
def apply(self, x):
from ..sugar import is_linearization, makeField
from .simple_linear_operators import VdotOperator
from ..linearization import Linearization
self._check_input(x)
lin = is_linearization(x)
val = x.val.val if lin else x.val
res = makeField(self._target, _jax2np(self._func(val)))
if not lin:
return res
jac = VdotOperator(makeField(self._domain, _jax2np(self._grad(val))))
res = Linearization(res, jac)
if not x.want_metric:
return res
return res.add_metric(self.get_metric_at(x.val))
class _JaxJacobian(LinearOperator):
def __init__(self, domain, target, func, adjfunc):
......
......@@ -62,6 +62,26 @@ def test_mf_jax():
ift.extra.check_operator(op, loc)
def test_jax_energy():
if _skip:
pytest.skip()
dom = ift.UnstructuredDomain((10, 2))
e0 = ift.GaussianEnergy(domain=dom)
def func(x):
return 0.5*jnp.vdot(x, x)
e = ift.JaxLikelihoodEnergyOperator(dom, func, transformation=ift.ScalingOperator(dom, 1.))
for wm in [False, True]:
pos = ift.from_random(e.domain)
lin = ift.Linearization.make_var(pos, wm)
ift.extra.assert_allclose(e0(pos), e(pos))
ift.extra.assert_allclose(e0(lin).val, e(lin).val)
ift.extra.assert_allclose(e0(lin).gradient, e(lin).gradient)
if not wm:
continue
pos1 = ift.from_random(e.domain)
ift.extra.assert_allclose(e0(lin).metric(pos1), e(lin).metric(pos1))
def test_cf():
dom = ift.RGSpace([13, 14], distances=(0.89, 0.9))
args = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment