diff --git a/demos/getting_started_density.py b/demos/getting_started_density.py index 62921ead7690670a146579c6b36ceeedf4a4583b..4a5c86a93a572e4af47d617d990e4da6c4e5fa2a 100644 --- a/demos/getting_started_density.py +++ b/demos/getting_started_density.py @@ -32,8 +32,8 @@ import nifty7 as ift def density_estimator( - domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None - ): + domain, pad=1., cf_fluctuations=None, cf_azm_uniform=None +): cf_azm_uniform_sane_default = (0., 20.) cf_fluctuations_sane_default = { "scale": (0.5, 0.3), @@ -57,13 +57,12 @@ def density_estimator( ) raise TypeError(te) shape_padded = tuple((d_scl * np.array(d.shape)).astype(int)) - domain_padded.append( - ift.RGSpace(shape_padded, distances=d.distances) - ) + domain_padded.append(ift.RGSpace(shape_padded, distances=d.distances)) domain_padded = ift.DomainTuple.make(domain_padded) # Set up the signal model - prefix = "de" # density estimator + prefix = "de_" # density estimator + azm_offset_mean = 0. # The zero-mode should be inferred only from the data cfmaker = ift.CorrelatedFieldMaker(prefix) for i, d in enumerate(domain_padded): if isinstance(cf_fluctuations, (list, tuple)): @@ -74,7 +73,6 @@ def density_estimator( scalar_domain = ift.DomainTuple.scalar_domain() uniform = ift.UniformOperator(scalar_domain, *cf_azm_uni) azm = uniform.ducktape("zeromode") - azm_offset_mean = 0. # The zero-mode should be inferred only from the data cfmaker.set_amplitude_total_offset(azm_offset_mean, azm) correlated_field = cfmaker.finalize(0) normalized_amplitudes = cfmaker.get_normalized_amplitudes() @@ -109,21 +107,25 @@ if __name__ == "__main__": rng = ift.random.current_rng() rng.standard_normal(1000) mock_position = ift.from_random(signal.domain, 'normal') - data = ift.Field.from_raw(data_space, rng.poisson(signal(mock_position).val)) + data = ift.Field.from_raw( + data_space, rng.poisson(signal(mock_position).val) + ) plot = ift.Plot() - plot.add(ift.exp(correlated_field(mock_position)), title='Pre-Slicing Truth') + plot.add( + ift.exp(correlated_field(mock_position)), title='Pre-Slicing Truth' + ) plot.add(signal(mock_position), title='Ground Truth') plot.add(data, title='Data') plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup")) # Minimization parameters - ic_sampling = ift.AbsDeltaEnergyController(name='Sampling', - deltaE=0.01, - iteration_limit=100) - ic_newton = ift.AbsDeltaEnergyController(name='Newton', - deltaE=0.01, - iteration_limit=35) + ic_sampling = ift.AbsDeltaEnergyController( + name='Sampling', deltaE=0.01, iteration_limit=100 + ) + ic_newton = ift.AbsDeltaEnergyController( + name='Newton', deltaE=0.01, iteration_limit=35 + ) ic_sampling.enable_logging() ic_newton.enable_logging() minimizer = ift.NewtonCG(ic_newton, enable_logging=True) @@ -150,16 +152,23 @@ if __name__ == "__main__": plot.add(ift.exp(correlated_field(mock_position)), title="ground truth") plot.add(signal(mock_position), title="ground truth") plot.add(signal(kl.position), title="reconstruction") - plot.add((ic_newton.history, ic_sampling.history, - minimizer.inversion_history), - label=['kl', 'Sampling', 'Newton inversion'], - title='Cumulative energies', s=[None, None, 1], - alpha=[None, 0.2, None]) - plot.output(nx=3, - ny=2, - ysize=10, - xsize=15, - name=filename.format(f"loop_{i:02d}")) + plot.add( + ( + ic_newton.history, ic_sampling.history, + minimizer.inversion_history + ), + label=['kl', 'Sampling', 'Newton inversion'], + title='Cumulative energies', + s=[None, None, 1], + alpha=[None, 0.2, None] + ) + plot.output( + nx=3, + ny=2, + ysize=10, + xsize=15, + name=filename.format(f"loop_{i:02d}") + ) # Done, draw posterior samples sc = ift.StatCalculator() @@ -174,7 +183,10 @@ if __name__ == "__main__": plot.add(sc.mean, title="Posterior Mean") plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation") plot.add(sc_unsliced.mean, title="Posterior Unsliced Mean") - plot.add(ift.sqrt(sc_unsliced.var), title="Posterior Unsliced Standard Deviation") + plot.add( + ift.sqrt(sc_unsliced.var), + title="Posterior Unsliced Standard Deviation" + ) plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res) print("Saved results as '{}'.".format(filename_res))