diff --git a/test/test_linearization.py b/test/test_linearization.py index a1b083e3fcf4b2c9f757f46237e6a6904ab997b1..f9973900ec9055db59fcfeb20b3935fbbeaf1890 100644 --- a/test/test_linearization.py +++ b/test/test_linearization.py @@ -25,7 +25,7 @@ pmp = pytest.mark.parametrize def _lin2grad(lin): - return lin.jac(ift.full(lin.domain, 1.)).local_data + return lin.jac(ift.full(lin.domain, 1.)).to_global_data() def jt(lin, check): @@ -36,7 +36,7 @@ def test_special_gradients(): dom = ift.UnstructuredDomain((1,)) f = ift.full(dom, 2.4) var = ift.Linearization.make_var(f) - s = f.local_data + s = f.to_global_data() jt(var.clip(0, 10), np.ones_like(s)) jt(var.clip(-1, 0), np.zeros_like(s)) @@ -62,8 +62,8 @@ def test_actual_gradients(f): eps = 1e-8 var0 = ift.Linearization.make_var(fld) var1 = ift.Linearization.make_var(fld + eps) - f0 = getattr(var0, f)().val.local_data - f1 = getattr(var1, f)().val.local_data + f0 = getattr(var0, f)().val.to_global_data() + f1 = getattr(var1, f)().val.to_global_data() df0 = (f1 - f0)/eps df1 = _lin2grad(getattr(var0, f)()) assert_allclose(df0, df1, rtol=100*eps)