Commit b98cb6b5 authored by Philipp Arras's avatar Philipp Arras
Browse files

Allow longer lines

parent 0c37585c
......@@ -31,9 +31,7 @@ import numpy as np
import nifty7 as ift
def density_estimator(
domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None
):
def density_estimator(domain, pad=1.0, cf_fluctuations=None, cf_azm_uniform=None):
cf_azm_uniform_sane_default = (1e-4, 1.0)
cf_fluctuations_sane_default = {
"scale": (0.5, 0.3),
......@@ -109,40 +107,34 @@ if __name__ == "__main__":
# Generate mock signal and data
rng = ift.random.current_rng()
mock_position = ift.from_random(signal.domain, 'normal')
data = ift.Field.from_raw(
data_space, rng.poisson(signal(mock_position).val)
)
mock_position = ift.from_random(signal.domain, "normal")
data = ift.Field.from_raw(data_space, rng.poisson(signal(mock_position).val))
# Rejoin domains for plotting
plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2)))
plotting_domain_expanded = ift.DomainTuple.make(
ift.RGSpace((2 * npix1, 2 * npix2))
)
plotting_domain_expanded = ift.DomainTuple.make(ift.RGSpace((2 * npix1, 2 * npix2)))
plot = ift.Plot()
plot.add(
ift.Field.from_raw(
plotting_domain_expanded,
ift.exp(correlated_field(mock_position)).val
plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val
),
title='Pre-Slicing Truth'
title="Pre-Slicing Truth",
)
plot.add(
ift.Field.from_raw(plotting_domain,
signal(mock_position).val),
title='Ground Truth'
ift.Field.from_raw(plotting_domain, signal(mock_position).val),
title="Ground Truth",
)
plot.add(ift.Field.from_raw(plotting_domain, data.val), title='Data')
plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data")
plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup"))
print("Setup saved as", filename.format("setup"))
# Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController(
name='Sampling', deltaE=0.01, iteration_limit=100
name="Sampling", deltaE=0.01, iteration_limit=100
)
ic_newton = ift.AbsDeltaEnergyController(
name='Newton', deltaE=0.01, iteration_limit=35
name="Newton", deltaE=0.01, iteration_limit=35
)
ic_sampling.enable_logging()
ic_newton.enable_logging()
......@@ -169,37 +161,27 @@ if __name__ == "__main__":
plot = ift.Plot()
plot.add(
ift.Field.from_raw(
plotting_domain_expanded,
ift.exp(correlated_field(mock_position)).val
plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val
),
title="Ground truth"
title="Ground truth",
)
plot.add(
ift.Field.from_raw(plotting_domain,
signal(mock_position).val),
title="Ground truth"
ift.Field.from_raw(plotting_domain, signal(mock_position).val),
title="Ground truth",
)
plot.add(
ift.Field.from_raw(plotting_domain,
signal(kl.position).val),
title="Reconstruction"
ift.Field.from_raw(plotting_domain, signal(kl.position).val),
title="Reconstruction",
)
plot.add(
(
ic_newton.history, ic_sampling.history,
minimizer.inversion_history
),
label=['kl', 'Sampling', 'Newton inversion'],
title='Cumulative energies',
(ic_newton.history, ic_sampling.history, minimizer.inversion_history),
label=["kl", "Sampling", "Newton inversion"],
title="Cumulative energies",
s=[None, None, 1],
alpha=[None, 0.2, None]
alpha=[None, 0.2, None],
)
plot.output(
nx=3,
ny=2,
ysize=10,
xsize=15,
name=filename.format(f"loop_{i:02d}")
nx=3, ny=2, ysize=10, xsize=15, name=filename.format(f"loop_{i:02d}")
)
# Done, draw posterior samples
......@@ -211,25 +193,18 @@ if __name__ == "__main__":
# Plotting
plot = ift.Plot()
plot.add(ift.Field.from_raw(plotting_domain, sc.mean.val), title="Posterior Mean")
plot.add(
ift.Field.from_raw(plotting_domain, sc.mean.val),
title="Posterior Mean"
)
plot.add(
ift.Field.from_raw(plotting_domain,
ift.sqrt(sc.var).val),
title="Posterior Standard Deviation"
ift.Field.from_raw(plotting_domain, ift.sqrt(sc.var).val),
title="Posterior Standard Deviation",
)
plot.add(
ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val),
title="Posterior Unsliced Mean"
title="Posterior Unsliced Mean",
)
plot.add(
ift.Field.from_raw(
plotting_domain_expanded,
ift.sqrt(sc_unsliced.var).val
),
title="Posterior Unsliced Standard Deviation"
ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(sc_unsliced.var).val),
title="Posterior Unsliced Standard Deviation",
)
filename_res = filename.format("results")
plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment