Commit 66a5e583 authored by Philipp Arras's avatar Philipp Arras

Migrate to simple CorrelatedField model

Use standard posterior sampling strategy
parent 59f7c2c0
Pipeline #77388 passed with stages
in 13 minutes and 21 seconds
......@@ -52,20 +52,17 @@ def main():
else:
mode = 0
filename = "getting_started_3_mode_{}_".format(mode) + "{}.png"
position_space = ift.RGSpace([128, 128])
# For a detailed showcase of the effects the parameters
# of the CorrelatedField model have on the generated fields,
# see 'getting_started_4_CorrelatedFields.ipynb'.
cfmaker = ift.CorrelatedFieldMaker.make(
offset_mean= 0.0,
offset_std_mean= 1e-3,
offset_std_std= 1e-6,
prefix='')
args = {
'offset_mean': 0,
'offset_std_mean': 1e-3,
'offset_std_std': 1e-6,
fluctuations_dict = {
# Amplitude of field fluctuations
'fluctuations_mean': 2.0, # 1.0
'fluctuations_stddev': 1.0, # 1e-2
......@@ -82,10 +79,9 @@ def main():
'asperity_mean': 0.5, # 0.1
'asperity_stddev': 0.5 # 0.5
}
cfmaker.add_fluctuations(position_space, **fluctuations_dict)
correlated_field = cfmaker.finalize()
A = cfmaker.amplitude
correlated_field = ift.SimpleCorrelatedField(position_space, **args)
A = correlated_field.amplitude
# Apply a nonlinearity
signal = ift.sigmoid(correlated_field)
......@@ -143,8 +139,6 @@ def main():
plot.output(ny=1, ysize=6, xsize=16,
name=filename.format("loop_{:02d}".format(i)))
# Draw posterior samples
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
sc = ift.StatCalculator()
for sample in KL.samples:
sc.add(signal(sample + KL.position))
......@@ -156,11 +150,15 @@ def main():
plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
powers = [A.force(s + KL.position) for s in KL.samples]
sc = ift.StatCalculator()
for pp in powers:
sc.add(pp)
plot.add(
powers + [A.force(mock_position),
A.force(KL.position)],
A.force(KL.position), sc.mean],
title="Sampled Posterior Power Spectrum",
linewidth=[1.]*len(powers) + [3., 3.])
linewidth=[1.]*len(powers) + [3., 3., 3.],
label=[None]*len(powers) + ['Ground truth', 'Posterior latent mean', 'Posterior mean'])
plot.output(ny=1, nx=3, xsize=24, ysize=6, name=filename_res)
print("Saved results as '{}'.".format(filename_res))
......
......@@ -105,3 +105,7 @@ class SimpleCorrelatedField(Operator):
self.apply = op.apply
self._domain = op.domain
self._target = op.target
@property
def amplitude(self):
return self._a
......@@ -172,7 +172,7 @@ def test_complicated_vs_simple(seed, domain):
ift.extra.assert_allclose(scf(inp), op1(inp))
ift.extra.check_operator(scf, inp, ntries=10)
# op1 = cfm.amplitude
# op0 = scf.amplitude
# assert_(op0.domain is op1.domain)
# ift.extra.assert_allclose(op0.force(inp), op1.force(inp))
op1 = cfm.amplitude
op0 = scf.amplitude
assert_(op0.domain is op1.domain)
ift.extra.assert_allclose(op0.force(inp), op1.force(inp))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment