Skip to content
Snippets Groups Projects
Commit 02d4686c authored by Philipp Arras's avatar Philipp Arras
Browse files

Change inputs and outputs

parent feb207b2
No related branches found
No related tags found
1 merge request!216Add adjointness and inverse test
Pipeline #
...@@ -164,16 +164,19 @@ class LinearOperator(with_metaclass( ...@@ -164,16 +164,19 @@ class LinearOperator(with_metaclass(
raise ValueError("The operator's and and field's domains " raise ValueError("The operator's and and field's domains "
"don't match.") "don't match.")
def test_adjointness(self, domain_dtype=np.float64, target_dtype=np.float64): def test_adjointness(self, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7):
f1 = Field.from_random("normal", domain=self.domain, f1 = Field.from_random("normal", domain=self.domain,
dtype=domain_dtype) dtype=domain_dtype)
f2 = Field.from_random("normal", domain=self.target, f2 = Field.from_random("normal", domain=self.target,
dtype=target_dtype) dtype=target_dtype)
res1 = f1.vdot(self.adjoint_times(f2)) res1 = f1.vdot(self.adjoint_times(f2))
res2 = self.times(f1).vdot(f2) res2 = self.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
# Return relative error
return (res1 - res2) / (res1 + res2) * 2 return (res1 - res2) / (res1 + res2) * 2
def test_inverse(self, dtype_domain=np.float64, dtype_target=np.float64): def test_inverse(self, dtype_domain=np.float64, dtype_target=np.float64, atol=0, rtol=1e-7):
foo = Field.from_random( foo = Field.from_random(
domain=self.target, random_type='normal', dtype=dtype_target) domain=self.target, random_type='normal', dtype=dtype_target)
bar = self.times(self.inverse_times(foo)).val bar = self.times(self.inverse_times(foo)).val
...@@ -181,7 +184,11 @@ class LinearOperator(with_metaclass( ...@@ -181,7 +184,11 @@ class LinearOperator(with_metaclass(
foo = Field.from_random( foo = Field.from_random(
domain=self.domain, random_type='normal', dtype=dtype_domain) domain=self.domain, random_type='normal', dtype=dtype_domain)
bar = self.inverse_times(self.times(foo)).val
np.testing.assert_allclose(bar, Field.ones(self.domain).val)
return True res1 = self.inverse_times(self.times(foo)).val
res2 = Field.ones(self.domain).val
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
# Return relative error
return (res1 - res2) / (res1 + res2) * 2
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment