Commit 4dca30b7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'metric_tests' into 'NIFTy_7'

Fix VariableCovarianceGaussianEnergy

See merge request ift/nifty!548
parents b5e93218 8d513a11
......@@ -171,7 +171,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
if not x.want_metric:
return res
met = 1. if self._cplx else 0.5
met = 1. if self._cplx else .5
met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
domain=self._domain)
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -61,7 +61,7 @@ def test_complex2real():
assert np.all((f == op(op.adjoint_times(f))).val)
def energy_tester(pos, get_noisy_data, energy_initializer):
def energy_tester(pos, get_noisy_data, energy_initializer, assume_diagonal=False):
if isinstance(pos, ift.Field):
if np.issubdtype(pos.dtype, np.complexfloating):
op = _complex2real(pos.domain)
......@@ -69,7 +69,7 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
op = ift.ScalingOperator(pos.domain, 1.)
else:
ops = []
for k,dom in pos.domain.items():
for k, dom in pos.domain.items():
if np.issubdtype(pos[k].dtype, np.complexfloating):
ops.append(_complex2real(dom).ducktape(k).ducktape_left(k))
else:
......@@ -79,29 +79,34 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
from nifty7.operator_spectrum import _DomRemover
flattener = _DomRemover(realizer.target)
op = flattener @ realizer
pos = op(pos)
npos = op(pos)
nget_noisy_data = lambda mean: get_noisy_data(op.adjoint_times(mean))
nenergy_initializer = lambda mean: energy_initializer(mean) @ op.adjoint
_actual_energy_tester(npos, nget_noisy_data, nenergy_initializer)
def _actual_energy_tester(pos, get_noisy_data, energy_initializer):
domain = pos.domain
test_vec = ift.from_random(domain, 'normal')
results = []
lin = ift.Linearization.make_var(pos)
for i in range(Nsamp):
data = get_noisy_data(pos)
energy = energy_initializer(data)
grad = energy(lin).jac.adjoint(ift.full(energy.target, 1.))
results.append(_to_array((grad*grad.s_vdot(test_vec)).val))
res = np.mean(np.array(results), axis=0)
std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)
energy = energy_initializer(data)
data = get_noisy_data(op.adjoint_times(pos))
energy = energy_initializer(data) @ op.adjoint
grad = energy(lin).gradient
if assume_diagonal:
results.append(_to_array((grad*grad.conjugate()).val))
else:
results.append(_to_array((grad*grad.s_vdot(test_vec)).val))
if assume_diagonal:
res = np.mean(np.array(results), axis=0)*_to_array(test_vec.val)
std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)*np.abs(_to_array(test_vec.val))
else:
res = np.mean(np.array(results), axis=0)
std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)
energy = energy_initializer(data) @ op.adjoint
lin = ift.Linearization.make_var(pos, want_metric=True)
res2 = _to_array(energy(lin).metric(test_vec).val)
np.testing.assert_allclose(res/std, res2/std, atol=6)
np.testing.assert_allclose(res/std, res2/std, atol=5)
# Test whether one would detect a factor of 2 in the Fisher metric
for factor in [0.5, 2]:
with pytest.raises(AssertionError):
np.testing.assert_allclose(res/std, factor*res2/std, atol=5)
def test_GaussianEnergy(field):
......@@ -125,17 +130,22 @@ def test_PoissonEnergy(field):
E_init = lambda data: ift.PoissonianEnergy(data)
energy_tester(lam, get_noisy_data, E_init)
def test_VariableCovarianceGaussianEnergy(dtype):
dom = ift.UnstructuredDomain(3)
res = ift.from_random(dom, 'normal', dtype=dtype)
ivar = ift.from_random(dom, 'normal')**2+4.
mf = ift.MultiField.from_dict({'res':res, 'ivar':ivar})
mf = ift.MultiField.from_dict({'res': res, 'ivar': ivar})
energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype)
def get_noisy_data(mean):
samp = ift.from_random(dom, 'normal', dtype)
samp = samp/mean['ivar'].sqrt()
return samp + mean['res']
def E_init(data):
adder = ift.Adder(ift.MultiField.from_dict({'res':data}), neg=True)
adder = ift.Adder(ift.MultiField.from_dict({'res': data}), neg=True)
return energy.partial_insert(adder)
energy_tester(mf, get_noisy_data, E_init)
energy_tester(mf, get_noisy_data, E_init, assume_diagonal=True)
def normal(dtype, shape):
return ift.random.Random.normal(dtype, shape)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment