Skip to content
Snippets Groups Projects
Select Git revision
  • pos_dep_nfft
  • pytorch_operator
  • main default protected
  • qpo_model_rebased
  • native_extension
  • joint_re_cl_tests
  • re_fewer_tests
  • perf_tweaks
  • NIFTy_8 protected
  • fix_nonlinearity_gradients
  • cupy_backend
  • nifty
  • nifty8_philipps_unmerged_patches
  • nifty_jr
  • frequency_model
  • 423-minisanity-re-improve-likelihood-readability
  • 420-tracerboolconversion-error-in-lognormal_moments-py
  • change_ncg_default
  • fix_saved_pickle_files
  • inline_issue
  • 9.1.0 protected
  • 9.0.0 protected
  • v8.5.7 protected
  • v8.5.6 protected
  • v8.5.5 protected
  • v8.5.4 protected
  • v8.5.3 protected
  • v8.5.2 protected
  • v8.5.1 protected
  • v8.5 protected
  • v8.4 protected
  • v8.3 protected
  • v8.2 protected
  • v8.1 protected
  • v8.0 protected
  • v7.5 protected
  • v7.4 protected
  • v7.3 protected
  • v7.2 protected
  • v7.1 protected
40 results

a_nonlinear_regression.py

Blame
  • a_nonlinear_regression.py 2.61 KiB
    #!/usr/bin/env python3
    
    # Copyright(C) 2013-2021 Max-Planck-Society
    # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
    
    # %% [markdown]
    # # Demonstration of a non-linear regression using NIFTy.re
    
    # %%
    import operator as op
    from functools import partial
    
    import jax
    import matplotlib.pyplot as plt
    import nifty8.re as jft
    import numpy as np
    from jax import numpy as jnp
    from jax import random
    
    jax.config.update("jax_enable_x64", True)
    
    # %%
    seed = 42
    key = random.PRNGKey(seed)
    
    
    class NonLinearRegression(jft.Model):
        def __init__(self, slope_mean, slope_std, intercept_min, intercept_max, x):
            self.slope = jft.LogNormalPrior(slope_mean, slope_std, name="slope")
            self.intercept = jft.UniformPrior(
                intercept_min, intercept_max, name="intercept"
            )
            self.x = x
            super().__init__(init=self.slope.init | self.intercept.init)
    
        def __call__(self, xi, *, x=None):
            x = x if x is not None else self.x
            return x * self.slope(xi) + self.intercept(xi)
    
    
    key, sk = random.split(key)
    x = random.uniform(sk, (50,), float, -4.0, 4.0)
    nlr = NonLinearRegression(3.0, 2.0, -5.0, 5.0, x)
    
    # %%
    noise_std = 5
    
    key_tr, key_n = random.split(random.PRNGKey(31415))
    y = nlr(nlr.init(key_tr)) + noise_std * random.normal(key_n, x.shape, x.dtype)
    
    # %%
    fig, ax = plt.subplots()
    ax.plot(x, y, color="dodgerblue", linestyle="None", marker=".", markersize=8)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    fig.tight_layout()
    fig.savefig("nonlinear_regression_data.png")
    plt.show()
    
    # %%
    lh = jft.Gaussian(y, noise_std_inv=partial(op.mul, 1.0 / noise_std)).amend(nlr)
    
    # %%
    key, ki, ko = random.split(key, 3)
    
    delta = 1e-4
    samples_opt, st = jft.optimize_kl(
        lh,
        jft.Vector(lh.init(ki)),
        key=ko,
        n_total_iterations=5,
        n_samples=12,
        draw_linear_kwargs=dict(
            cg_kwargs=dict(absdelta=delta * jft.size(lh.domain) / 10.0),
        ),
        nonlinearly_update_kwargs=dict(
            minimize_kwargs=dict(xtol=delta, cg_kwargs=dict(name=None))
        ),
        kl_kwargs=dict(
            minimize_kwargs=dict(name="M", xtol=delta, cg_kwargs=dict(name=None))
        ),
        sample_mode="nonlinear_resample",
    )
    
    # %%
    x_p = np.linspace(x.min(), x.max(), 500)
    y_unnoised_samples = jax.vmap(partial(nlr, x=x_p))(samples_opt.samples)
    
    fig, ax = plt.subplots()
    ax.plot(x, y, color="dodgerblue", linestyle="None", marker=".", markersize=8)
    ax.plot(x_p, y_unnoised_samples.mean(axis=0), color="black")
    qs = (0.16, 0.84)
    ax.fill_between(
        x_p, *np.quantile(y_unnoised_samples, qs, axis=0), color="gray", alpha=0.3
    )
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    fig.tight_layout()
    fig.savefig("nonlinear_regression_posterior.png", dpi=400)
    plt.show()