Commit b99e06c0 authored by Reimar Leike's avatar Reimar Leike
Browse files

Adjusted fisher test to always make Fisher matrices reaL, code for GaussianEnergy was reverted

parent d799dd37
Pipeline #76780 passed with stages
in 12 minutes and 9 seconds
......@@ -244,18 +244,7 @@ class GaussianEnergy(EnergyOperator):
self._met = inverse_covariance
if sampling_dtype is not None:
self._met = SamplingDtypeSetter(self._met, sampling_dtype)
if isinstance(sampling_dtype, dict):
from .sandwich_operator import SandwichOperator
scale = {k:np.sqrt(2.) if np.issubdtype(v, np.complexfloating)
else 1. for k,v in sampling_dtype.items()}
scale = _build_MultiScalingOperator(self._domain, scale)
self._met = SandwichOperator.make(scale, self._met)
else:
if np.issubdtype(sampling_dtype, np.complexfloating):
from .sandwich_operator import SandwichOperator
scale = ScalingOperator(self._met.domain,np.sqrt(2))
self._met = SandwichOperator.make(scale, self._met)
def _checkEquivalence(self, newdom):
newdom = makeDomain(newdom)
if self._domain is None:
......
......@@ -23,7 +23,6 @@ import nifty6 as ift
from ..common import list2fixture, setup_function, teardown_function
spaces = [ift.GLSpace(5),
ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}),
(ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))]
pmp = pytest.mark.parametrize
field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
......@@ -38,7 +37,34 @@ def _to_array(d):
assert isinstance(d, dict)
return np.concatenate(list(d.values()))
def _complex2real(sp):
tup = tuple([d for d in sp])
rsp = ift.DomainTuple.make((ift.UnstructuredDomain(2),) + tup)
rl = ift.DomainTupleFieldInserter(rsp, 0, (0,))
im = ift.DomainTupleFieldInserter(rsp, 0, (1,))
x = ift.ScalingOperator(sp, 1)
return rl(x.real)+im(x.imag)
def test_complex2real():
sp = ift.UnstructuredDomain(3)
op = _complex2real(ift.makeDomain(sp))
f = ift.from_random(op.domain, 'normal', dtype=np.complex128)
assert np.all((f == op.adjoint_times(op(f))).val)
assert op(f).dtype == np.float64
f = ift.from_random(op.target, 'normal')
assert np.all((f == op(op.adjoint_times(f))).val)
def energy_tester_complex(pos, get_noisy_data, energy_initializer):
op = _complex2real(pos.domain)
npos = op(pos)
nget_noisy_data = lambda mean : get_noisy_data(op.adjoint_times(mean))
nenergy_initializer = lambda mean : energy_initializer(mean) @ op.adjoint
energy_tester(npos, nget_noisy_data, nenergy_initializer)
def energy_tester(pos, get_noisy_data, energy_initializer):
if np.issubdtype(pos.dtype, np.complexfloating):
energy_tester_complex(pos, get_noisy_data, energy_initializer)
return
domain = pos.domain
test_vec = ift.from_random(domain, 'normal')
results = []
......@@ -48,6 +74,8 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
energy = energy_initializer(data)
grad = energy(lin).jac.adjoint(ift.full(energy.target, 1.))
results.append(_to_array((grad*grad.s_vdot(test_vec)).val))
print(energy)
print(grad)
res = np.mean(np.array(results), axis=0)
std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)
energy = energy_initializer(data)
......@@ -57,6 +85,7 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
def test_GaussianEnergy(field):
dtype = field.dtype
icov = ift.from_random(field.domain, 'normal')**2
icov = ift.makeOp(icov)
get_noisy_data = lambda mean : mean + icov.draw_sample_with_dtype(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment