Skip to content
Snippets Groups Projects
Commit 1c1a6b51 authored by Philipp Arras's avatar Philipp Arras
Browse files

MPI fixups

parent 820ddd40
Branches
Tags
1 merge request!339Linearizationtests
Pipeline #52936 passed
......@@ -25,7 +25,7 @@ pmp = pytest.mark.parametrize
def _lin2grad(lin):
return lin.jac(ift.full(lin.domain, 1.)).local_data
return lin.jac(ift.full(lin.domain, 1.)).to_global_data()
def jt(lin, check):
......@@ -36,7 +36,7 @@ def test_special_gradients():
dom = ift.UnstructuredDomain((1,))
f = ift.full(dom, 2.4)
var = ift.Linearization.make_var(f)
s = f.local_data
s = f.to_global_data()
jt(var.clip(0, 10), np.ones_like(s))
jt(var.clip(-1, 0), np.zeros_like(s))
......@@ -62,8 +62,8 @@ def test_actual_gradients(f):
eps = 1e-8
var0 = ift.Linearization.make_var(fld)
var1 = ift.Linearization.make_var(fld + eps)
f0 = getattr(var0, f)().val.local_data
f1 = getattr(var1, f)().val.local_data
f0 = getattr(var0, f)().val.to_global_data()
f1 = getattr(var1, f)().val.to_global_data()
df0 = (f1 - f0)/eps
df1 = _lin2grad(getattr(var0, f)())
assert_allclose(df0, df1, rtol=100*eps)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment