Commit 60037674 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add more thorough tests

parent bd796fe2
Pipeline #104722 canceled with stages
in 3 minutes and 21 seconds
......@@ -122,7 +122,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
if not lin:
return res
jac = VdotOperator(makeField(self._domain, _jax2np(self._grad(val))))
res = Linearization(res, jac)
res = x.new(res, jac)
if not x.want_metric:
return res
return res.add_metric(self.get_metric_at(x.val))
......@@ -131,8 +131,12 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
func2 = lambda x: self._func({**x, **c_inp.val})
dom = {kk: vv for kk, vv in self._domain.items()
if kk not in c_inp.keys()}
return None, JaxLikelihoodEnergyOperator(dom, func2,
self._trafo, self._dt)
_, trafo = self._trafo.simplify_for_constant_input(c_inp)
if isinstance(self._dt, dict):
dt = {kk: self._dt[kk] for kk in dom.keys()}
else:
dt = self._dt
return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt)
class _JaxJacobian(LinearOperator):
......
......@@ -78,7 +78,9 @@ def test_jax_energy(dom):
return 0.5*res
e = ift.JaxLikelihoodEnergyOperator(dom,
funcmf if isinstance(dom, ift.MultiDomain) else func,
transformation=ift.ScalingOperator(dom, 1.))
transformation=ift.ScalingOperator(dom, 1.),
sampling_dtype=np.float64)
ift.extra.check_operator(e, ift.from_random(e.domain))
for wm in [False, True]:
pos = ift.from_random(e.domain)
lin = ift.Linearization.make_var(pos, wm)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment