Commit b98cb6b5 authored by Philipp Arras's avatar Philipp Arras
Browse files

Allow longer lines

parent 0c37585c
...@@ -31,9 +31,7 @@ import numpy as np ...@@ -31,9 +31,7 @@ import numpy as np
import nifty7 as ift import nifty7 as ift
def density_estimator( def density_estimator(domain, pad=1.0, cf_fluctuations=None, cf_azm_uniform=None):
domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None
):
cf_azm_uniform_sane_default = (1e-4, 1.0) cf_azm_uniform_sane_default = (1e-4, 1.0)
cf_fluctuations_sane_default = { cf_fluctuations_sane_default = {
"scale": (0.5, 0.3), "scale": (0.5, 0.3),
...@@ -109,40 +107,34 @@ if __name__ == "__main__": ...@@ -109,40 +107,34 @@ if __name__ == "__main__":
# Generate mock signal and data # Generate mock signal and data
rng = ift.random.current_rng() rng = ift.random.current_rng()
mock_position = ift.from_random(signal.domain, 'normal') mock_position = ift.from_random(signal.domain, "normal")
data = ift.Field.from_raw( data = ift.Field.from_raw(data_space, rng.poisson(signal(mock_position).val))
data_space, rng.poisson(signal(mock_position).val)
)
# Rejoin domains for plotting # Rejoin domains for plotting
plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2))) plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2)))
plotting_domain_expanded = ift.DomainTuple.make( plotting_domain_expanded = ift.DomainTuple.make(ift.RGSpace((2 * npix1, 2 * npix2)))
ift.RGSpace((2 * npix1, 2 * npix2))
)
plot = ift.Plot() plot = ift.Plot()
plot.add( plot.add(
ift.Field.from_raw( ift.Field.from_raw(
plotting_domain_expanded, plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val
ift.exp(correlated_field(mock_position)).val
), ),
title='Pre-Slicing Truth' title="Pre-Slicing Truth",
) )
plot.add( plot.add(
ift.Field.from_raw(plotting_domain, ift.Field.from_raw(plotting_domain, signal(mock_position).val),
signal(mock_position).val), title="Ground Truth",
title='Ground Truth'
) )
plot.add(ift.Field.from_raw(plotting_domain, data.val), title='Data') plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data")
plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup")) plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup"))
print("Setup saved as", filename.format("setup")) print("Setup saved as", filename.format("setup"))
# Minimization parameters # Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController( ic_sampling = ift.AbsDeltaEnergyController(
name='Sampling', deltaE=0.01, iteration_limit=100 name="Sampling", deltaE=0.01, iteration_limit=100
) )
ic_newton = ift.AbsDeltaEnergyController( ic_newton = ift.AbsDeltaEnergyController(
name='Newton', deltaE=0.01, iteration_limit=35 name="Newton", deltaE=0.01, iteration_limit=35
) )
ic_sampling.enable_logging() ic_sampling.enable_logging()
ic_newton.enable_logging() ic_newton.enable_logging()
...@@ -169,37 +161,27 @@ if __name__ == "__main__": ...@@ -169,37 +161,27 @@ if __name__ == "__main__":
plot = ift.Plot() plot = ift.Plot()
plot.add( plot.add(
ift.Field.from_raw( ift.Field.from_raw(
plotting_domain_expanded, plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val
ift.exp(correlated_field(mock_position)).val
), ),
title="Ground truth" title="Ground truth",
) )
plot.add( plot.add(
ift.Field.from_raw(plotting_domain, ift.Field.from_raw(plotting_domain, signal(mock_position).val),
signal(mock_position).val), title="Ground truth",
title="Ground truth"
) )
plot.add( plot.add(
ift.Field.from_raw(plotting_domain, ift.Field.from_raw(plotting_domain, signal(kl.position).val),
signal(kl.position).val), title="Reconstruction",
title="Reconstruction"
) )
plot.add( plot.add(
( (ic_newton.history, ic_sampling.history, minimizer.inversion_history),
ic_newton.history, ic_sampling.history, label=["kl", "Sampling", "Newton inversion"],
minimizer.inversion_history title="Cumulative energies",
),
label=['kl', 'Sampling', 'Newton inversion'],
title='Cumulative energies',
s=[None, None, 1], s=[None, None, 1],
alpha=[None, 0.2, None] alpha=[None, 0.2, None],
) )
plot.output( plot.output(
nx=3, nx=3, ny=2, ysize=10, xsize=15, name=filename.format(f"loop_{i:02d}")
ny=2,
ysize=10,
xsize=15,
name=filename.format(f"loop_{i:02d}")
) )
# Done, draw posterior samples # Done, draw posterior samples
...@@ -211,25 +193,18 @@ if __name__ == "__main__": ...@@ -211,25 +193,18 @@ if __name__ == "__main__":
# Plotting # Plotting
plot = ift.Plot() plot = ift.Plot()
plot.add(ift.Field.from_raw(plotting_domain, sc.mean.val), title="Posterior Mean")
plot.add( plot.add(
ift.Field.from_raw(plotting_domain, sc.mean.val), ift.Field.from_raw(plotting_domain, ift.sqrt(sc.var).val),
title="Posterior Mean" title="Posterior Standard Deviation",
)
plot.add(
ift.Field.from_raw(plotting_domain,
ift.sqrt(sc.var).val),
title="Posterior Standard Deviation"
) )
plot.add( plot.add(
ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val), ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val),
title="Posterior Unsliced Mean" title="Posterior Unsliced Mean",
) )
plot.add( plot.add(
ift.Field.from_raw( ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(sc_unsliced.var).val),
plotting_domain_expanded, title="Posterior Unsliced Standard Deviation",
ift.sqrt(sc_unsliced.var).val
),
title="Posterior Unsliced Standard Deviation"
) )
filename_res = filename.format("results") filename_res = filename.format("results")
plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res) plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment