Skip to content
Snippets Groups Projects
Select Git revision
  • f06df48595d34cf61f49f88d66e61637ac46113b
  • master default protected
  • feature/particle_state_generation_with_variable_box_size
  • feature/add-fft-interface
  • feature/expose-rnumber-from-simulations
  • feature/forcing-unit-test
  • feature/dealias-check2
  • bugfix/check_field_exists
  • feature/dealias-check
  • v3.x
  • feature/particles-vectorization
  • 6.2.4
  • 6.2.3
  • 6.2.2
  • 6.2.1
  • 6.2.0
  • 6.1.0
  • 6.0.0
  • 5.8.1
  • 5.8.0
  • 5.7.2
  • 5.7.1
  • 5.7.0
  • 5.6.0
  • 5.5.1
  • 5.5.0
  • 5.4.7
  • 5.4.6
  • 5.4.5
  • 5.4.4
  • 5.4.3
31 results

joint_acc_vel_stats.hpp

Blame
  • 0_intro.py 5.10 KiB
    #!/usr/bin/env python3
    
    # Copyright(C) 2013-2021 Max-Planck-Society
    # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
    
    # %% [markdown]
    # # Demonstration of the non-parametric correlated field model in NIFTy.re
    
    # ## The Model
    
    # %%
    import jax
    import matplotlib.pyplot as plt
    import nifty8.re as jft
    from jax import numpy as jnp
    from jax import random
    
    # Enable 64-bit (double precision) floats in JAX.
    jax.config.update("jax_enable_x64", True)

    # Fixed seed so that the synthetic data and the inference are reproducible.
    seed = 42
    key = random.PRNGKey(seed)

    # Shape of the regular grid the signal is defined on.
    dims = (128, 128)

    # Prior settings for the overall offset of the correlated field.
    # NOTE(review): tuples are presumably (mean, std) pairs of the respective
    # hyper-priors -- confirm against the NIFTy.re CorrelatedFieldMaker docs.
    cf_zm = {"offset_mean": 0.0, "offset_std": (1e-3, 1e-4)}
    # Prior settings for the fluctuations / power spectrum of the field.
    cf_fl = {
        "fluctuations": (1e-1, 5e-3),
        "loglogavgslope": (-1.0, 1e-2),
        "flexibility": (1e0, 5e-1),
        "asperity": (5e-1, 5e-2),
    }

    # Assemble the correlated-field model from the offset and fluctuation priors.
    cfm = jft.CorrelatedFieldMaker("cf")
    cfm.set_amplitude_total_offset(**cf_zm)
    cfm.add_fluctuations(
        dims,
        distances=1.0 / dims[0],
        prefix="ax1",
        non_parametric_kind="power",
        **cf_fl,
    )
    correlated_field = cfm.finalize()

    # Log-normal prior for a global multiplicative scale of the signal.
    scaling = jft.LogNormalPrior(3.0, 1.0, name="scaling", shape=(1,))
    
    
    class Signal(jft.Model):
        """Signal model: the scaling model times the exponentiated correlated field.

        Parameters
        ----------
        correlated_field
            Model whose output is exponentiated in ``__call__``.
        scaling
            Model whose output multiplies the exponentiated field.
        """

        def __init__(self, correlated_field, scaling):
            self.cf = correlated_field
            self.scaling = scaling
            # Track the combined `init` methods of both sub-models rather than
            # their `domain`: the `init` methods of the Correlated Field model and
            # of NIFTy.re prior models know that their input is standard normal a
            # priori, whereas a `domain` carries no such information.
            merged_init = self.cf.init | self.scaling.init
            super().__init__(init=merged_init)

        def __call__(self, x):
            # Think of a `Model` as a plain function of its input parameters that
            # performs all the computation of the model.
            # In the likelihood, `scaling` is fully degenerate with any constant
            # offset of the correlated field (scaling * exp(cf) equals
            # exp(log(scaling) + cf)); only their priors differ.
            scale = self.scaling(x)
            log_signal = self.cf(x)
            return scale * jnp.exp(log_signal)


    signal = Signal(correlated_field=correlated_field, scaling=scaling)
    
    # %% [markdown]
    # ## The likelihood
    
    # %%
    # The response is the identity: the data are the signal plus noise.
    signal_response = signal
    # Noise standard deviation; the covariance below is noise_std**2 times the
    # identity, and its inverse is noise_std**-2 times the identity.
    noise_std = 0.1
    noise_cov = lambda x: noise_std**2 * x
    noise_cov_inv = lambda x: noise_std**-2 * x

    # Create synthetic data: draw a "true" latent position, push it through the
    # model, and add noise scaled by the square root of the noise covariance.
    key, subkey = random.split(key)
    pos_truth = jft.random_like(subkey, signal_response.domain)
    signal_response_truth = signal_response(pos_truth)
    # Draw the noise with the freshly split `subkey`. Using `key` here would
    # reuse a key that is split again later, which correlates random streams.
    key, subkey = random.split(key)
    noise_truth = (
        (noise_cov(jft.ones_like(signal_response.target))) ** 0.5
    ) * jft.random_like(subkey, signal_response.target)
    data = signal_response_truth + noise_truth

    # Gaussian likelihood of the data, amended with the signal response so it
    # is a function of the latent parameters.
    lh = jft.Gaussian(data, noise_cov_inv).amend(signal_response)
    
    # %% [markdown]
    # ## The inference
    
    # %%
    # Settings of the variational inference.
    n_vi_iterations = 6
    delta = 1e-4
    n_samples = 4

    key, k_i, k_o = random.split(key, 3)
    # NOTE: changing the number of samples always triggers a resampling, even
    # with `resamples=False`, because the extra samples did not exist before.
    samples, state = jft.optimize_kl(
        lh,
        jft.Vector(lh.init(k_i)),
        n_total_iterations=n_vi_iterations,
        # Half the number of samples for iterations i < 2, the full number after.
        n_samples=lambda i: n_samples // 2 if i < 2 else n_samples,
        # PRNG key that is the source of the stochasticity of the sampling.
        key=k_o,
        # Parameters that should be optimized but not sampled can be listed in
        # `point_estimates` (effectively a MAP estimate for those degrees of
        # freedom).
        # point_estimates=("cfax1flexibility", "cfax1asperity"),
        # Settings for the conjugate gradient method used to draw samples.
        draw_linear_kwargs={
            "cg_name": "SL",
            "cg_kwargs": {
                "absdelta": delta * jft.size(lh.domain) / 10.0,
                "maxiter": 100,
            },
        },
        # Settings for the minimizer of the nonlinear sample update.
        nonlinearly_update_kwargs={
            "minimize_kwargs": {
                "name": "SN",
                "xtol": delta,
                "cg_kwargs": {"name": None},
                "maxiter": 5,
            },
        },
        # Settings for the minimizer of the KL-divergence cost function.
        kl_kwargs={
            "minimize_kwargs": {
                "name": "M",
                "xtol": delta,
                "cg_kwargs": {"name": None},
                "maxiter": 35,
            },
        },
        sample_mode="nonlinear_resample",
        odir="results_intro",
        resume=False,
    )
    
    # %%
    # Normalized amplitudes (not used in the plots below).
    namps = cfm.get_normalized_amplitudes()
    # Posterior summaries: means over the samples of the signal and of the
    # amplitude spectrum. The first spectrum entry is dropped (presumably the
    # zero mode, which cannot be shown on a log-log axis).
    post_sr_mean = jft.mean(tuple(signal(s) for s in samples))
    post_a_mean = jft.mean(tuple(cfm.amplitude(s)[1:] for s in samples))
    grid = correlated_field.target_grids[0]
    # Each entry is (title, data, plot type, *curve labels). For line plots the
    # data are (x, y_1, y_2, ...) and the optional labels name the y curves.
    to_plot = [
        ("Signal", signal(pos_truth), "im"),
        ("Noise", noise_truth, "im"),
        ("Data", data, "im"),
        ("Reconstruction", post_sr_mean, "im"),
        (
            "Amplitude spectrum",
            (
                grid.harmonic_grid.mode_lengths[1:],
                cfm.amplitude(pos_truth)[1:],
                post_a_mean,
            ),
            "loglog",
            "truth",
            "posterior mean",
        ),
    ]
    # Physical extent of the grid, shared by all image panels.
    end = tuple(n * d for n, d in zip(grid.shape, grid.distances))
    extent = (0.0, end[0], 0.0, end[1])

    fig, axs = plt.subplots(2, 3, figsize=(16, 9))
    for ax, (title, field, tp, *labels) in zip(axs.flat, to_plot):
        ax.set_title(title)
        if tp == "im":
            im = ax.imshow(field.T, cmap="inferno", extent=extent)
            plt.colorbar(im, ax=ax, orientation="horizontal")
        else:
            ax_plot = ax.loglog if tp == "loglog" else ax.plot
            x, *ys = field
            for i, y in enumerate(ys):
                # Attach a legend label only to curves that have one.
                label_kw = {"label": labels[i]} if i < len(labels) else {}
                ax_plot(x, y, alpha=0.7, **label_kw)
            if labels:
                ax.legend()
    # Hide the panels of the 2x3 grid that received no plot.
    for ax in axs.flat[len(to_plot) :]:
        ax.set_axis_off()
    fig.tight_layout()
    fig.savefig("results_intro_full_reconstruction.png", dpi=400)
    plt.show()