diff --git a/demos/getting_started_density.py b/demos/getting_started_density.py index 55be5200a2050c92977f70026bcffa5003851d5f..f4f92306a56f1654ed34e6fff48ede75e99b1090 100644 --- a/demos/getting_started_density.py +++ b/demos/getting_started_density.py @@ -31,9 +31,7 @@ import numpy as np import nifty7 as ift -def density_estimator( - domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None -): +def density_estimator(domain, pad=1.0, cf_fluctuations=None, cf_azm_uniform=None): cf_azm_uniform_sane_default = (1e-4, 1.0) cf_fluctuations_sane_default = { "scale": (0.5, 0.3), @@ -109,40 +107,34 @@ if __name__ == "__main__": # Generate mock signal and data rng = ift.random.current_rng() - mock_position = ift.from_random(signal.domain, 'normal') - data = ift.Field.from_raw( - data_space, rng.poisson(signal(mock_position).val) - ) + mock_position = ift.from_random(signal.domain, "normal") + data = ift.Field.from_raw(data_space, rng.poisson(signal(mock_position).val)) # Rejoin domains for plotting plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2))) - plotting_domain_expanded = ift.DomainTuple.make( - ift.RGSpace((2 * npix1, 2 * npix2)) - ) + plotting_domain_expanded = ift.DomainTuple.make(ift.RGSpace((2 * npix1, 2 * npix2))) plot = ift.Plot() plot.add( ift.Field.from_raw( - plotting_domain_expanded, - ift.exp(correlated_field(mock_position)).val + plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val ), - title='Pre-Slicing Truth' + title="Pre-Slicing Truth", ) plot.add( - ift.Field.from_raw(plotting_domain, - signal(mock_position).val), - title='Ground Truth' + ift.Field.from_raw(plotting_domain, signal(mock_position).val), + title="Ground Truth", ) - plot.add(ift.Field.from_raw(plotting_domain, data.val), title='Data') + plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data") plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup")) print("Setup saved as", filename.format("setup")) # Minimization parameters ic_sampling = ift.AbsDeltaEnergyController( - name='Sampling', deltaE=0.01, iteration_limit=100 + name="Sampling", deltaE=0.01, iteration_limit=100 ) ic_newton = ift.AbsDeltaEnergyController( - name='Newton', deltaE=0.01, iteration_limit=35 + name="Newton", deltaE=0.01, iteration_limit=35 ) ic_sampling.enable_logging() ic_newton.enable_logging() @@ -169,37 +161,27 @@ if __name__ == "__main__": plot = ift.Plot() plot.add( ift.Field.from_raw( - plotting_domain_expanded, - ift.exp(correlated_field(mock_position)).val + plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val ), - title="Ground truth" + title="Ground truth", ) plot.add( - ift.Field.from_raw(plotting_domain, - signal(mock_position).val), - title="Ground truth" + ift.Field.from_raw(plotting_domain, signal(mock_position).val), + title="Ground truth", ) plot.add( - ift.Field.from_raw(plotting_domain, - signal(kl.position).val), - title="Reconstruction" + ift.Field.from_raw(plotting_domain, signal(kl.position).val), + title="Reconstruction", ) plot.add( - ( - ic_newton.history, ic_sampling.history, - minimizer.inversion_history - ), - label=['kl', 'Sampling', 'Newton inversion'], - title='Cumulative energies', + (ic_newton.history, ic_sampling.history, minimizer.inversion_history), + label=["kl", "Sampling", "Newton inversion"], + title="Cumulative energies", s=[None, None, 1], - alpha=[None, 0.2, None] + alpha=[None, 0.2, None], ) plot.output( - nx=3, - ny=2, - ysize=10, - xsize=15, - name=filename.format(f"loop_{i:02d}") + nx=3, ny=2, ysize=10, xsize=15, name=filename.format(f"loop_{i:02d}") ) # Done, draw posterior samples @@ -211,25 +193,18 @@ if __name__ == "__main__": # Plotting plot = ift.Plot() + plot.add(ift.Field.from_raw(plotting_domain, sc.mean.val), title="Posterior Mean") plot.add( - ift.Field.from_raw(plotting_domain, sc.mean.val), - title="Posterior Mean" - ) - plot.add( - ift.Field.from_raw(plotting_domain, - ift.sqrt(sc.var).val), - title="Posterior Standard Deviation" + ift.Field.from_raw(plotting_domain, ift.sqrt(sc.var).val), + title="Posterior Standard Deviation", ) plot.add( ift.Field.from_raw(plotting_domain_expanded, sc_unsliced.mean.val), - title="Posterior Unsliced Mean" + title="Posterior Unsliced Mean", ) plot.add( - ift.Field.from_raw( - plotting_domain_expanded, - ift.sqrt(sc_unsliced.var).val - ), - title="Posterior Unsliced Standard Deviation" + ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(sc_unsliced.var).val), + title="Posterior Unsliced Standard Deviation", ) filename_res = filename.format("results") plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res)