Commit a3805bbe authored by Philipp Arras's avatar Philipp Arras
Browse files

JaxOperator: Support complex functions

parent 404027de
......@@ -364,7 +364,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin = op(lin)
myassert(oplin.jac.target is oplin0.jac.target)
rndinp = from_random(oplin.jac.target)
rndinp = from_random(oplin.jac.target, dtype=oplin.val.dtype)
assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
......
......@@ -165,12 +165,12 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
class _JaxJacobian(LinearOperator):
def __init__(self, domain, target, func, adjfunc):
def __init__(self, domain, target, func, func_transposed):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._adjfunc = adjfunc
self._func_transposed = func_transposed
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
......@@ -178,6 +178,6 @@ class _JaxJacobian(LinearOperator):
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
else:
fx = self._adjfunc(x.val)[0]
return makeField(self._tgt(mode), _jax2np(fx))
return makeField(self._tgt(mode), _jax2np(fx))
fx = self._func_transposed(x.conjugate().val)[0]
return makeField(self._tgt(mode), _jax2np(fx)).conjugate()
......@@ -104,3 +104,41 @@ def test_jax_errors():
op = ift.JaxOperator(dom, mdom, lambda x: {"a": x[0]})
with pytest.raises(ValueError):
op(fld)
def test_jax_complex():
dom = ift.UnstructuredDomain(1)
a = ift.ducktape(dom, None, "a")
b = ift.ducktape(dom, None, "b")
op = a.real+1j*b.real
op1 = ift.JaxOperator(op.domain, op.target, lambda x: x["a"] + 1j*x["b"])
_op_equal(op, op1, ift.from_random(op.domain))
ift.extra.check_operator(op, ift.from_random(op.domain), ntries=10)
ift.extra.check_operator(op1, ift.from_random(op.domain), ntries=10)
op = op.imag
op1 = op1.imag
_op_equal(op, op1, ift.from_random(op.domain))
ift.extra.check_operator(op, ift.from_random(op.domain), ntries=10)
ift.extra.check_operator(op1, ift.from_random(op.domain), ntries=10)
lin = ift.Linearization.make_var(ift.from_random(op.domain))
test_vec = ift.full(op.target, 1.)
grad = op(lin).jac.adjoint(test_vec)
grad1 = op1(lin).jac.adjoint(test_vec)
ift.extra.assert_equal(grad, grad1)
ift.extra.assert_equal(grad, ift.makeField(grad.domain, {"a": 0., "b": 1.}))
def _op_equal(op0, op1, loc):
assert op0.domain is op1.domain
assert op0.target is op1.target
ift.extra.assert_allclose(op0(loc), op1(loc))
lin = ift.Linearization.make_var(loc)
res = op0(lin)
res1 = op1(lin)
ift.extra.assert_allclose(res.val, res1.val)
fld = ift.from_random(op0.domain, dtype=loc.dtype)
ift.extra.assert_allclose(res.jac(fld), res1.jac(fld))
fld = ift.from_random(op0.target, dtype=res.val.dtype)
ift.extra.assert_allclose(res.jac.adjoint(fld), res1.jac.adjoint(fld))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment