Commit 441fffe6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Compile forward derivative

parent 820803ac
Pipeline #104399 passed with stages
in 28 minutes and 4 seconds
......@@ -55,13 +55,14 @@ class JaxOperator(Operator):
self._target = makeDomain(target)
self._func = jax.jit(func)
self._vjp = jax.jit(lambda x: jax.vjp(func, x))
self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1])
def apply(self, x):
from ..sugar import is_linearization, makeField
self._check_input(x)
if is_linearization(x):
res, bwd = self._vjp(x.val.val)
fwd = lambda y: jax.jvp(self._func, (x.val.val,), (y,))[1]
fwd = lambda y: self._fwd(x.val.val, y)
jac = _JaxJacobian(self._domain, self._target, fwd, bwd)
return x.new(makeField(self._target, _jax2np(res)), jac)
return makeField(self._target, _jax2np(self._func(x.val)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment