Commit bd796fe2 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add MultiDomain tests

parent 05e3cb74
Pipeline #104721 canceled with stages
in 12 minutes and 24 seconds
......@@ -127,6 +127,13 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
return res
return res.add_metric(self.get_metric_at(x.val))
def _simplify_for_constant_input_nontrivial(self, c_inp):
func2 = lambda x: self._func({**x, **c_inp.val})
dom = {kk: vv for kk, vv in self._domain.items()
if kk not in c_inp.keys()}
return None, JaxLikelihoodEnergyOperator(dom, func2,
self._trafo, self._dt)
class _JaxJacobian(LinearOperator):
def __init__(self, domain, target, func, adjfunc):
......
......@@ -62,14 +62,23 @@ def test_mf_jax():
ift.extra.check_operator(op, loc)
def test_jax_energy():
@pmp("dom", [ift.RGSpace((10, 8)),
{"a": ift.RGSpace(10), "b": ift.UnstructuredDomain(2)}])
def test_jax_energy(dom):
if _skip:
pytest.skip()
dom = ift.UnstructuredDomain((10, 2))
dom = ift.makeDomain(dom)
e0 = ift.GaussianEnergy(domain=dom)
def func(x):
return 0.5*jnp.vdot(x, x)
e = ift.JaxLikelihoodEnergyOperator(dom, func, transformation=ift.ScalingOperator(dom, 1.))
def funcmf(x):
res = 0
for kk, vv in x.items():
res += jnp.vdot(vv, vv)
return 0.5*res
e = ift.JaxLikelihoodEnergyOperator(dom,
funcmf if isinstance(dom, ift.MultiDomain) else func,
transformation=ift.ScalingOperator(dom, 1.))
for wm in [False, True]:
pos = ift.from_random(e.domain)
lin = ift.Linearization.make_var(pos, wm)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment