Skip to content
Snippets Groups Projects
Select Git revision
  • c4416f36c522dcbe66d9e7da622d04f7ea56c729
  • develop default protected
  • fuji-image
  • ellips_example_fix
  • 150-xps-example-update
  • sts_oasis
  • sts_examples
  • xps_eln_update
  • update-mpes-example
  • add-specsscan-fhi
  • xrd_example
  • downgrade-h5web
  • fixed-datascience-notebook-version-develop
  • fixed-datascience-notebook-version
  • ChangesForMigration
  • fhi-pcr840
  • add-nomad-lab-to-jupyterlab
  • fix_requirements
  • main protected
  • v0.0.2
  • v0.0.1
  • with-north-api
22 results

config.py

Blame
  • density_estimation.py 5.02 KiB
    #!/usr/bin/env python3
    
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 3 of the License, or
    # (at your option) any later version.
    #
    # This program is distributed in the hope that it will be useful,
    # but WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    # GNU General Public License for more details.
    #
    # You should have received a copy of the GNU General Public License
    # along with this program.  If not, see <http://www.gnu.org/licenses/>.
    #
    # Copyright(C) 2013-2021 Max-Planck-Society
    #
    # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
    
    ############################################################
    # Density estimation
    #
    # Compute a density estimate for a log-normal process measured by a
    # Poissonian likelihood.
    #
    # Demo takes a while to compute
    #############################################################
    
    import nifty8 as ift
    
    if __name__ == "__main__":
        # Demo driver: draw a mock signal from the density-estimation model,
        # generate Poisson count data from it, reconstruct the signal by
        # repeatedly minimizing a sampled KL energy, and write diagnostic plots
        # to PNG files named after `filename`.
        filename = "getting_started_density_{}.png"
        ift.random.push_sseq_from_seed(42)  # fixed seed -> reproducible demo

        # Set up signal domain: a product of two one-dimensional regular grids
        npix1 = 128
        npix2 = 128
        sp1 = ift.RGSpace(npix1)
        sp2 = ift.RGSpace(npix2)
        position_space = ift.DomainTuple.make((sp1, sp2))

        signal, ops = ift.density_estimator(position_space)
        correlated_field = ops["correlated_field"]

        data_space = signal.target

        # Generate mock signal and data.  The order of the random draws is kept
        # (latent position first, Poisson counts second) so the seeded output
        # does not change.
        rng = ift.random.current_rng()
        mock_position = ift.from_random(signal.domain, "normal")
        # Evaluate the model once at the mock position and reuse the result for
        # the data draw and for every plot below.
        mock_signal = signal(mock_position)
        data = ift.Field.from_raw(data_space, rng.poisson(mock_signal.val))

        # Rejoin domains for plotting: the two 1D spaces are merged into a single
        # 2D space.  The field before slicing is plotted on a grid that is twice
        # as large along each axis.
        plotting_domain = ift.DomainTuple.make(ift.RGSpace((npix1, npix2)))
        plotting_domain_expanded = ift.DomainTuple.make(
            ift.RGSpace((2 * npix1, 2 * npix2))
        )

        # Ground-truth fields do not depend on the reconstruction, so build them
        # once instead of once per plot / per loop iteration.
        truth_unsliced = ift.Field.from_raw(
            plotting_domain_expanded, ift.exp(correlated_field(mock_position)).val
        )
        truth = ift.Field.from_raw(plotting_domain, mock_signal.val)

        plot = ift.Plot()
        plot.add(truth_unsliced, title="Pre-Slicing Truth")
        plot.add(truth, title="Ground Truth")
        plot.add(ift.Field.from_raw(plotting_domain, data.val), title="Data")
        plot.output(ny=1, nx=3, xsize=10, ysize=3, name=filename.format("setup"))
        print("Setup saved as", filename.format("setup"))

        # Minimization parameters
        ic_sampling = ift.AbsDeltaEnergyController(
            name="Sampling", deltaE=0.01, iteration_limit=100
        )
        ic_newton = ift.AbsDeltaEnergyController(
            name="Newton", deltaE=0.01, iteration_limit=35
        )
        # Logging is required for the energy-history panel plotted in the loop.
        ic_sampling.enable_logging()
        ic_newton.enable_logging()
        minimizer = ift.NewtonCG(ic_newton, enable_logging=True)

        # Number of samples used to estimate the KL
        n_samples = 5
        # Number of "draw samples, then minimize" rounds
        n_iterations = 3

        # Set up likelihood energy and information Hamiltonian
        likelihood_energy = ift.PoissonianEnergy(data) @ signal
        ham = ift.StandardHamiltonian(likelihood_energy, ic_sampling)

        # Start minimization from a small random latent position
        position = 1e-2 * ift.from_random(ham.domain)

        for i in range(n_iterations):
            # Draw new samples and minimize KL.
            # NOTE(review): the fourth positional argument is presumably the
            # sampling minimizer; None is passed through unchanged -- confirm
            # against the nifty8 SampledKLEnergy signature.
            kl = ift.SampledKLEnergy(position, ham, n_samples, None)
            kl, _ = minimizer(kl)  # second return value is not used here
            position = kl.position

            # Plot current reconstruction
            plot = ift.Plot()
            # Same field as in the setup plot, hence the same title (this panel
            # was previously mislabelled as a second "Ground truth").
            plot.add(truth_unsliced, title="Pre-Slicing Truth")
            plot.add(truth, title="Ground truth")
            plot.add(
                ift.Field.from_raw(plotting_domain, kl.samples.average(signal).val),
                title="Reconstruction",
            )
            plot.add(
                (ic_newton.history, ic_sampling.history, minimizer.inversion_history),
                label=["kl", "Sampling", "Newton inversion"],
                title="Cumulative energies",
                s=[None, None, 1],
                alpha=[None, 0.2, None],
            )
            plot.output(
                nx=3, ny=2, ysize=10, xsize=15, name=filename.format(f"loop_{i:02d}")
            )

        # Done, compute posterior statistics from the samples of the last round
        posterior_mean, posterior_var = kl.samples.sample_stat(signal)
        mean_unsliced, var_unsliced = kl.samples.sample_stat(correlated_field.exp())

        # Plotting
        plot = ift.Plot()
        plot.add(
            ift.Field.from_raw(plotting_domain, posterior_mean.val),
            title="Posterior Mean",
        )
        plot.add(
            ift.Field.from_raw(plotting_domain, ift.sqrt(posterior_var).val),
            title="Posterior Standard Deviation",
        )
        plot.add(
            ift.Field.from_raw(plotting_domain_expanded, mean_unsliced.val),
            title="Posterior Unsliced Mean",
        )
        plot.add(
            ift.Field.from_raw(plotting_domain_expanded, ift.sqrt(var_unsliced).val),
            title="Posterior Unsliced Standard Deviation",
        )
        plot.output(xsize=15, ysize=15, name=filename.format("results"))
        print("Saved results as", filename.format("results"))