Commit 8205d1a9 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add assert_equal() and cleanup

parent 00c9005f
Pipeline #76160 passed with stages
in 13 minutes and 35 seconds
...@@ -37,6 +37,13 @@ def assert_allclose(f1, f2, atol, rtol): ...@@ -37,6 +37,13 @@ def assert_allclose(f1, f2, atol, rtol):
assert_allclose(val, f2[key], atol=atol, rtol=rtol) assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def assert_equal(f1, f2):
if isinstance(f1, Field):
return np.testing.assert_equal(f1.val, f2.val)
for key, val in f1.items():
assert_equal(val, f2[key])
def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol, def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear): only_r_linear):
needed_cap = op.TIMES | op.ADJOINT_TIMES needed_cap = op.TIMES | op.ADJOINT_TIMES
...@@ -249,7 +256,7 @@ def _linearization_value_consistency(op, loc): ...@@ -249,7 +256,7 @@ def _linearization_value_consistency(op, loc):
def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True, def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True,
only_r_differentiable=True): only_r_differentiable=True):
""" """
Checks the Jacobian of an operator against its finite difference Checks the Jacobian of an operator against its finite difference
approximation. approximation.
......
...@@ -270,14 +270,13 @@ class GaussianEnergy(EnergyOperator): ...@@ -270,14 +270,13 @@ class GaussianEnergy(EnergyOperator):
if sampling_dtype != _field_to_dtype(self._mean): if sampling_dtype != _field_to_dtype(self._mean):
raise ValueError("Sampling dtype and mean not compatible") raise ValueError("Sampling dtype and mean not compatible")
self._icov = inverse_covariance
if inverse_covariance is None: if inverse_covariance is None:
self._op = Squared2NormOperator(self._domain).scale(0.5) self._op = Squared2NormOperator(self._domain).scale(0.5)
self._met = ScalingOperator(self._domain, 1) self._met = ScalingOperator(self._domain, 1)
self._trivial_invcov = True
else: else:
self._op = QuadraticFormOperator(inverse_covariance) self._op = QuadraticFormOperator(inverse_covariance)
self._met = inverse_covariance self._met = inverse_covariance
self._trivial_invcov = False
if sampling_dtype is not None: if sampling_dtype is not None:
self._met = SamplingDtypeSetter(self._met, sampling_dtype) self._met = SamplingDtypeSetter(self._met, sampling_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment