Commit 8205d1a9 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add assert_equal() and cleanup

parent 00c9005f
Pipeline #76160 passed with stages
in 13 minutes and 35 seconds
......@@ -37,6 +37,13 @@ def assert_allclose(f1, f2, atol, rtol):
assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def assert_equal(f1, f2):
if isinstance(f1, Field):
return np.testing.assert_equal(f1.val, f2.val)
for key, val in f1.items():
assert_equal(val, f2[key])
def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear):
needed_cap = op.TIMES | op.ADJOINT_TIMES
......@@ -249,7 +256,7 @@ def _linearization_value_consistency(op, loc):
def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True,
only_r_differentiable=True):
only_r_differentiable=True):
"""
Checks the Jacobian of an operator against its finite difference
approximation.
......
......@@ -270,14 +270,13 @@ class GaussianEnergy(EnergyOperator):
if sampling_dtype != _field_to_dtype(self._mean):
raise ValueError("Sampling dtype and mean not compatible")
self._icov = inverse_covariance
if inverse_covariance is None:
self._op = Squared2NormOperator(self._domain).scale(0.5)
self._met = ScalingOperator(self._domain, 1)
self._trivial_invcov = True
else:
self._op = QuadraticFormOperator(inverse_covariance)
self._met = inverse_covariance
self._trivial_invcov = False
if sampling_dtype is not None:
self._met = SamplingDtypeSetter(self._met, sampling_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment