Skip to content
Snippets Groups Projects
Commit 092d641f authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

getting_started_density.py: Format

parent 845f93c1
No related branches found
No related tags found
1 merge request!607Revamp zero-mode handling
......@@ -32,8 +32,8 @@ import nifty7 as ift
def density_estimator(
domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None
):
domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None
):
cf_azm_uniform_sane_default = (0., 20.)
cf_fluctuations_sane_default = {
"scale": (0.5, 0.3),
......@@ -57,13 +57,12 @@ def density_estimator(
)
raise TypeError(te)
shape_padded = tuple((d_scl * np.array(d.shape)).astype(int))
domain_padded.append(
ift.RGSpace(shape_padded, distances=d.distances)
)
domain_padded.append(ift.RGSpace(shape_padded, distances=d.distances))
domain_padded = ift.DomainTuple.make(domain_padded)
# Set up the signal model
prefix = "de" # density estimator
prefix = "de_" # density estimator
azm_offset_mean = 0. # The zero-mode should be inferred only from the data
cfmaker = ift.CorrelatedFieldMaker(prefix)
for i, d in enumerate(domain_padded):
if isinstance(cf_fluctuations, (list, tuple)):
......@@ -74,7 +73,6 @@ def density_estimator(
scalar_domain = ift.DomainTuple.scalar_domain()
uniform = ift.UniformOperator(scalar_domain, *cf_azm_uni)
azm = uniform.ducktape("zeromode")
azm_offset_mean = 0. # The zero-mode should be inferred only from the data
cfmaker.set_amplitude_total_offset(azm_offset_mean, azm)
correlated_field = cfmaker.finalize(0)
normalized_amplitudes = cfmaker.get_normalized_amplitudes()
......@@ -109,21 +107,25 @@ if __name__ == "__main__":
rng = ift.random.current_rng()
rng.standard_normal(1000)
mock_position = ift.from_random(signal.domain, 'normal')
data = ift.Field.from_raw(data_space, rng.poisson(signal(mock_position).val))
data = ift.Field.from_raw(
data_space, rng.poisson(signal(mock_position).val)
)
plot = ift.Plot()
plot.add(ift.exp(correlated_field(mock_position)), title='Pre-Slicing Truth')
plot.add(
ift.exp(correlated_field(mock_position)), title='Pre-Slicing Truth'
)
plot.add(signal(mock_position), title='Ground Truth')
plot.add(data, title='Data')
plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup"))
# Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController(name='Sampling',
deltaE=0.01,
iteration_limit=100)
ic_newton = ift.AbsDeltaEnergyController(name='Newton',
deltaE=0.01,
iteration_limit=35)
ic_sampling = ift.AbsDeltaEnergyController(
name='Sampling', deltaE=0.01, iteration_limit=100
)
ic_newton = ift.AbsDeltaEnergyController(
name='Newton', deltaE=0.01, iteration_limit=35
)
ic_sampling.enable_logging()
ic_newton.enable_logging()
minimizer = ift.NewtonCG(ic_newton, enable_logging=True)
......@@ -150,16 +152,23 @@ if __name__ == "__main__":
plot.add(ift.exp(correlated_field(mock_position)), title="ground truth")
plot.add(signal(mock_position), title="ground truth")
plot.add(signal(kl.position), title="reconstruction")
plot.add((ic_newton.history, ic_sampling.history,
minimizer.inversion_history),
label=['kl', 'Sampling', 'Newton inversion'],
title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
plot.output(nx=3,
ny=2,
ysize=10,
xsize=15,
name=filename.format(f"loop_{i:02d}"))
plot.add(
(
ic_newton.history, ic_sampling.history,
minimizer.inversion_history
),
label=['kl', 'Sampling', 'Newton inversion'],
title='Cumulative energies',
s=[None, None, 1],
alpha=[None, 0.2, None]
)
plot.output(
nx=3,
ny=2,
ysize=10,
xsize=15,
name=filename.format(f"loop_{i:02d}")
)
# Done, draw posterior samples
sc = ift.StatCalculator()
......@@ -174,7 +183,10 @@ if __name__ == "__main__":
plot.add(sc.mean, title="Posterior Mean")
plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
plot.add(sc_unsliced.mean, title="Posterior Unsliced Mean")
plot.add(ift.sqrt(sc_unsliced.var), title="Posterior Unsliced Standard Deviation")
plot.add(
ift.sqrt(sc_unsliced.var),
title="Posterior Unsliced Standard Deviation"
)
plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)
print("Saved results as '{}'.".format(filename_res))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment