diff --git a/demos/getting_started_density.py b/demos/getting_started_density.py index f4f92306a56f1654ed34e6fff48ede75e99b1090..3ea2f0741e36df7cd5208739a6f4a1d7a11e713f 100644 --- a/demos/getting_started_density.py +++ b/demos/getting_started_density.py @@ -49,19 +49,16 @@ def density_estimator(domain, pad=1.0, cf_fluctuations=None, cf_azm_uniform=None domain_padded = [] for d_scl, d in zip(dom_scaling, domain): if not isinstance(d, ift.RGSpace) or d.harmonic: - te = ( - f"unexpected domain encountered in `domain`: {domain}\n" - "expected a non-harmonic `ift.RGSpace`" - ) - raise TypeError(te) + te = [f"unexpected domain encountered in `domain`: {domain}"] + te += "expected a non-harmonic `ift.RGSpace`" + raise TypeError("\n".join(te)) shape_padded = tuple((d_scl * np.array(d.shape)).astype(int)) domain_padded.append(ift.RGSpace(shape_padded, distances=d.distances)) domain_padded = ift.DomainTuple.make(domain_padded) # Set up the signal model - prefix = "de_" # density estimator - azm_offset_mean = 0. # The zero-mode should be inferred only from the data - cfmaker = ift.CorrelatedFieldMaker(prefix) + azm_offset_mean = 0.0 # The zero-mode should be inferred only from the data + cfmaker = ift.CorrelatedFieldMaker("") for i, d in enumerate(domain_padded): if isinstance(cf_fluctuations, (list, tuple)): cf_fl = cf_fluctuations[i] @@ -126,7 +123,7 @@ if __name__ == "__main__": title="Ground Truth", ) plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data") - plot.output(ny=1, nx=3, xsize=10, ysize=10, name=filename.format("setup")) + plot.output(ny=1, nx=3, xsize=10, ysize=3, name=filename.format("setup")) print("Setup saved as", filename.format("setup")) # Minimization parameters @@ -206,6 +203,5 @@ if __name__ == "__main__": ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(sc_unsliced.var).val), title="Posterior Unsliced Standard Deviation", ) - filename_res = filename.format("results") - plot.output(ny=2, nx=2, xsize=15, ysize=15, name=filename_res) - print("Saved results as '{}'.".format(filename_res)) + plot.output(xsize=15, ysize=15, name=filename.format("results")) + print("Saved results as", filename.format("results"))