Commit 652b282d authored by Philipp Arras's avatar Philipp Arras
Browse files

Fixups

parent 36f5e23e
Pipeline #24363 failed with stage
in 4 minutes and 9 seconds
...@@ -5,26 +5,26 @@ from ..field import Field ...@@ -5,26 +5,26 @@ from ..field import Field
__all__ = ['test_adjointness', 'test_inverse'] __all__ = ['test_adjointness', 'test_inverse']
def test_adjointness(self, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7): def test_adjointness(op, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7):
f1 = Field.from_random("normal", domain=self.domain, dtype=domain_dtype) f1 = Field.from_random("normal", domain=op.domain, dtype=domain_dtype)
f2 = Field.from_random("normal", domain=self.target, dtype=target_dtype) f2 = Field.from_random("normal", domain=op.target, dtype=target_dtype)
res1 = f1.vdot(self.adjoint_times(f2)) res1 = f1.vdot(op.adjoint_times(f2))
res2 = self.times(f1).vdot(f2) res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
# Return relative error # Return relative error
return (res1 - res2) / (res1 + res2) * 2 return (res1 - res2) / (res1 + res2) * 2
def test_inverse(self, dtype_domain=np.float64, dtype_target=np.float64, atol=0, rtol=1e-7): def test_inverse(op, dtype_domain=np.float64, dtype_target=np.float64, atol=0, rtol=1e-7):
foo = Field.from_random(domain=self.target, random_type='normal', dtype=dtype_target) foo = Field.from_random(domain=op.target, random_type='normal', dtype=dtype_target)
res = self.times(self.inverse_times(foo)).val res = op(op.inverse_times(foo)).val
ones = Field.ones(self.domain).val ones = Field.ones(op.domain).val
np.testing.assert_allclose(res, ones, atol=atol, rtol=rtol) np.testing.assert_allclose(res, ones, atol=atol, rtol=rtol)
foo = Field.from_random(domain=self.domain, random_type='normal', dtype=dtype_domain) foo = Field.from_random(domain=op.domain, random_type='normal', dtype=dtype_domain)
res = self.inverse_times(self.times(foo)).val res = op.inverse_times(op(foo)).val
ones = Field.ones(self.target).val ones = Field.ones(op.target).val
np.testing.assert_allclose(res, ones, atol=atol, rtol=rtol) np.testing.assert_allclose(res, ones, atol=atol, rtol=rtol)
# Return relative error # Return relative error
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment