Commit 404027de authored by Philipp Arras's avatar Philipp Arras
Browse files

JaxOperator: Add checks for output for debugging

parent 2db6c449
......@@ -61,13 +61,38 @@ class JaxOperator(Operator):
def apply(self, x):
from ..sugar import is_linearization, makeField
from ..multi_domain import MultiDomain
self._check_input(x)
if is_linearization(x):
res, bwd = self._vjp(x.val.val)
fwd = lambda y: self._fwd(x.val.val, y)
jac = _JaxJacobian(self._domain, self._target, fwd, bwd)
return x.new(makeField(self._target, _jax2np(res)), jac)
return makeField(self._target, _jax2np(self._func(x.val)))
res = _jax2np(self._func(x.val))
if isinstance(res, dict):
if not isinstance(self._target, MultiDomain):
raise TypeError(("Jax function return a dictionary although the "
"target of the operator is a DomainTuple."))
if set(res.keys()) != set(self._target.keys()):
raise ValueError(("Keys do not match:\n"
f"Target keys: {self._target.keys()}\n"
f"Jax function returns: {res.keys()}"))
for kk in res.keys():
self._check_shape(self._target[kk].shape, res[kk].shape)
else:
if isinstance(self._target, MultiDomain):
raise TypeError(("Jax function does not return a dictionary "
"although the target of the operator is a "
"MultiDomain."))
self._check_shape(self._target.shape, res.shape)
return makeField(self._target, res)
@staticmethod
def _check_shape(shp_tgt, shp_jax):
if shp_tgt != shp_jax:
raise ValueError(("Output shapes do not match:\n"
f"Target shape is\t\t{shp_tgt}\n"
f"Jax function returns\t{shp_jax}"))
def _simplify_for_constant_input_nontrivial(self, c_inp):
func2 = lambda x: self._func({**x, **c_inp.val})
......
......@@ -86,3 +86,21 @@ def test_jax_energy(dom):
continue
pos1 = ift.from_random(e.domain)
ift.extra.assert_allclose(e0(lin).metric(pos1), e(lin).metric(pos1))
def test_jax_errors():
dom = ift.UnstructuredDomain(2)
mdom = {"a": dom}
op = ift.JaxOperator(dom, dom, lambda x: {"a": x})
fld = ift.full(dom, 0.)
with pytest.raises(TypeError):
op(fld)
op = ift.JaxOperator(dom, mdom, lambda x: x)
with pytest.raises(TypeError):
op(fld)
op = ift.JaxOperator(dom, dom, lambda x: x[0])
with pytest.raises(ValueError):
op(fld)
op = ift.JaxOperator(dom, mdom, lambda x: {"a": x[0]})
with pytest.raises(ValueError):
op(fld)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment