Commit 8dca7e53 authored by Reimar Leike's avatar Reimar Leike
Browse files

Increase Fisher test sensitivity

This was achieved by exploiting the diagonality of the Fisher metric.
In cases this is not true one needs more samples to increase sensitivity.
Tested that a factor of 2 only in the inverse covariance can now be detected.
parent 17449c6d
......@@ -171,7 +171,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
if not x.want_metric:
return res
met = 1. if self._cplx else 0.5
met = 1. if self._cplx else .5
met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
domain=self._domain)
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
......
......@@ -61,7 +61,7 @@ def test_complex2real():
assert np.all((f == op(op.adjoint_times(f))).val)
def energy_tester(pos, get_noisy_data, energy_initializer):
def energy_tester(pos, get_noisy_data, energy_initializer, assume_diagonal=None):
if isinstance(pos, ift.Field):
if np.issubdtype(pos.dtype, np.complexfloating):
op = _complex2real(pos.domain)
......@@ -89,7 +89,14 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
data = get_noisy_data(op.adjoint_times(pos))
energy = energy_initializer(data) @ op.adjoint
grad = energy(lin).gradient
if assume_diagonal:
results.append(_to_array((grad*grad.conjugate()).val))
else:
results.append(_to_array((grad*grad.s_vdot(test_vec)).val))
if assume_diagonal:
res = np.mean(np.array(results), axis=0)*_to_array(test_vec.val)
std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)*np.abs(_to_array(test_vec.val))
else:
res = np.mean(np.array(results), axis=0)
std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)
energy = energy_initializer(data) @ op.adjoint
......@@ -137,7 +144,7 @@ def test_VariableCovarianceGaussianEnergy(dtype):
def E_init(data):
adder = ift.Adder(ift.MultiField.from_dict({'res':data}), neg=True)
return energy.partial_insert(adder)
energy_tester(mf, get_noisy_data, E_init)
energy_tester(mf, get_noisy_data, E_init, assume_diagonal=True)
def normal(dtype, shape):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment