Commit f557057f authored by Philipp Arras's avatar Philipp Arras
Browse files

Use jax.value_and_grad for energy operator

parent 272829d1
Pipeline #104760 passed with stages
in 16 minutes and 18 seconds
......@@ -101,7 +101,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._func = jax.jit(func)
self._grad = jax.jit(jax.grad(func))
self._val_and_grad = jax.jit(jax.value_and_grad(func))
self._dt = sampling_dtype
self._trafo = transformation
......@@ -118,11 +118,11 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
self._check_input(x)
lin = is_linearization(x)
val = x.val.val if lin else x.val
res = makeField(self._target, _jax2np(self._func(val)))
if not lin:
return res
jac = VdotOperator(makeField(self._domain, _jax2np(self._grad(val))))
res = x.new(res, jac)
return makeField(self._target, _jax2np(self._func(val)))
res, grad = self._val_and_grad(val)
jac = VdotOperator(makeField(self._domain, _jax2np(grad)))
res = x.new(makeField(self._target, _jax2np(res)), jac)
if not x.want_metric:
return res
return res.add_metric(self.get_metric_at(x.val))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment