Skip to content
Snippets Groups Projects
Commit de31a28f authored by Philipp Haim's avatar Philipp Haim
Browse files

More tests for correlated field

parent d2c454b4
No related branches found
No related tags found
1 merge request!399Fix new finalize method
......@@ -100,11 +100,23 @@ def _log_vol(power_space):
return logk_lengths[1:] - logk_lengths[:-1]
def _structured_spaces(domain):
if isinstance(domain[0], UnstructuredDomain):
return np.arange(1, len(domain))
return np.arange(len(domain))
def _total_fluctuation_realized(samples):
spaces = _structured_spaces(samples[0].domain)
co = ContractionOperator(samples[0].domain, spaces)
size = co.domain.size/co.target.size
res = 0.
for s in samples:
res = res + (s - s.mean())**2
return np.sqrt((res/len(samples)).mean())
res = res + (s - co.adjoint(co(s)/size))**2
res = res.mean(spaces)/len(samples)
if np.isscalar(res):
return np.sqrt(res)
return np.sqrt(res.val)
class _LognormalMomentMatching(Operator):
......@@ -576,10 +588,13 @@ class CorrelatedFieldMaker:
@staticmethod
def offset_amplitude_realized(samples):
spaces = _structured_spaces(samples[0].domain)
res = 0.
for s in samples:
res = res + s.mean()**2
return np.sqrt(res/len(samples))
res = res + s.mean(spaces)**2
if np.isscalar(res):
return np.sqrt(res/len(samples))
return np.sqrt(res.val/len(samples))
@staticmethod
def total_fluctuation_realized(samples):
......@@ -589,36 +604,46 @@ class CorrelatedFieldMaker:
def slice_fluctuation_realized(samples, space):
"""Computes slice fluctuations from collection of field (defined in signal
space) realizations."""
ldom = len(samples[0].domain)
if space >= ldom:
spaces = _structured_spaces(samples[0].domain)
if space >= len(spaces):
raise ValueError("invalid space specified; got {!r}".format(space))
if ldom == 1:
if len(spaces) == 1:
return _total_fluctuation_realized(samples)
space = space + spaces[0]
res1, res2 = 0., 0.
for s in samples:
res1 = res1 + s**2
res2 = res2 + s.mean(space)**2
res1 = res1/len(samples)
res2 = res2/len(samples)
res = res1.mean() - res2.mean()
return np.sqrt(res)
res = res1.mean(spaces) - res2.mean(spaces[:-1])
if np.isscalar(res):
return np.sqrt(res)
return np.sqrt(res.val)
@staticmethod
def average_fluctuation_realized(samples, space):
"""Computes average fluctuations from collection of field (defined in signal
space) realizations."""
ldom = len(samples[0].domain)
if space >= ldom:
spaces = _structured_spaces(samples[0].domain)
if space >= len(spaces):
raise ValueError("invalid space specified; got {!r}".format(space))
if ldom == 1:
if len(spaces) == 1:
return _total_fluctuation_realized(samples)
spaces = ()
for i in range(ldom):
if i != space:
spaces += (i,)
space = space + spaces[0]
sub_spaces = set(spaces)
sub_spaces.remove(space)
sub_dom = makeDomain([samples[0].domain[ind]
for ind in set([0,]) | set([space,])])
co = ContractionOperator(sub_dom, len(sub_dom)-1)
res = 0.
for s in samples:
r = s.mean(spaces)
res = res + (r - r.mean())**2
res = res/len(samples)
return np.sqrt(res.mean())
r = s.mean(sub_spaces)
if min(spaces) == 0:
res = res + (r - r.mean(spaces[:-1]))**2
else:
res = res + (r - co.adjoint(r.mean(spaces[:-1])))**2
res = res.mean(spaces[0])/len(samples)
if np.isscalar(res):
return np.sqrt(res)
return np.sqrt(res.val)
......@@ -31,7 +31,8 @@ import nifty6 as ift
@pytest.mark.parametrize('rseed', [13, 2])
@pytest.mark.parametrize('Astds', [[1., 3.], [0.2, 1.4]])
@pytest.mark.parametrize('offset_std', [1., 10.])
def testAmplitudesConsistency(rseed, sspace, Astds, offset_std):
@pytest.mark.parametrize('N', [0,2])
def testAmplitudesConsistency(rseed, sspace, Astds, offset_std, N):
def stats(op, samples):
sc = ift.StatCalculator()
for s in samples:
......@@ -42,17 +43,23 @@ def testAmplitudesConsistency(rseed, sspace, Astds, offset_std):
nsam = 100
fsspace = ift.RGSpace((12,), (0.4,))
if N==2:
dofdex1 = [0,0]
dofdex2 = [1,0]
dofdex3 = [1,1]
else:
dofdex1, dofdex2, dofdex3 = None, None, None
fa = ift.CorrelatedFieldMaker.make(offset_std, 1E-8, '')
fa = ift.CorrelatedFieldMaker.make(offset_std, 1E-8, '', N, dofdex1)
fa.add_fluctuations(sspace, Astds[0], 1E-8, 1.1, 2., 2.1, .5, -2, 1.,
'spatial')
'spatial', dofdex = dofdex2)
fa.add_fluctuations(fsspace, Astds[1], 1E-8, 3.1, 1., .5, .1, -4, 1.,
'freq')
'freq', dofdex = dofdex3)
op = fa.finalize()
samples = [ift.from_random('normal', op.domain) for _ in range(nsam)]
tot_flm, _ = stats(fa.total_fluctuation, samples)
offset_std, _ = stats(fa.amplitude_total_offset, samples)
offset_amp_std, _ = stats(fa.amplitude_total_offset, samples)
intergated_fluct_std0, _ = stats(fa.average_fluctuation(0), samples)
intergated_fluct_std1, _ = stats(fa.average_fluctuation(1), samples)
......@@ -67,22 +74,23 @@ def testAmplitudesConsistency(rseed, sspace, Astds, offset_std):
sl_fluct_space = fa.slice_fluctuation_realized(sams, 0)
sl_fluct_freq = fa.slice_fluctuation_realized(sams, 1)
assert_allclose(offset_std, zm_std_mean, rtol=0.5)
assert_allclose(offset_amp_std, zm_std_mean, rtol=0.5)
assert_allclose(intergated_fluct_std0, fluct_space, rtol=0.5)
assert_allclose(intergated_fluct_std1, fluct_freq, rtol=0.5)
assert_allclose(tot_flm, fluct_total, rtol=0.5)
assert_allclose(slice_fluct_std0, sl_fluct_space, rtol=0.5)
assert_allclose(slice_fluct_std1, sl_fluct_freq, rtol=0.5)
fa = ift.CorrelatedFieldMaker.make(offset_std, .1, '')
fa.add_fluctuations(fsspace, Astds[1], 1., 3.1, 1., .5, .1, -4, 1., 'freq')
fa = ift.CorrelatedFieldMaker.make(offset_std, .1, '', N, dofdex1)
fa.add_fluctuations(fsspace, Astds[1], 1., 3.1, 1., .5, .1, -4, 1., 'freq', dofdex = dofdex3)
m = 3.
x = fa.moment_slice_to_average(m)
fa.add_fluctuations(sspace, x, 1.5, 1.1, 2., 2.1, .5, -2, 1., 'spatial', 0)
fa.add_fluctuations(sspace, x, 1.5, 1.1, 2., 2.1, .5, -2, 1., 'spatial', 0, dofdex = dofdex2)
op = fa.finalize()
em, estd = stats(fa.slice_fluctuation(0), samples)
assert_allclose(m, em, rtol=0.5)
assert op.target[0] == sspace
assert op.target[1] == fsspace
assert op.target[-2] == sspace
assert op.target[-1] == fsspace
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment