diff --git a/demos/re/geomap.py b/demos/re/geomap.py deleted file mode 100644 index b36fc075c2e58d0bc21fa638bc40a7e611762127..0000000000000000000000000000000000000000 --- a/demos/re/geomap.py +++ /dev/null @@ -1,313 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright(C) 2013-2021 Max-Planck-Society -# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause - -from functools import partial -import sys - -from jax import numpy as jnp -from jax import random -from jax import jit -import jax -from jax import random -import matplotlib.pyplot as plt - -import nifty8.re as jft - -jax.config.update("jax_enable_x64", True) - - -# %% -def lanczos_logdet( - mat, - v, - order: int, -): - """Computes a stochastic estimate of the log-determinate of the Lanczos - decomposed matrix. This is not the same as applying the stochastic Lanczos - quadrature algorithm as it estimates the log-determinate for the - decomposition only. - """ - mat = mat.__matmul__ if not hasattr(mat, "__call__") else mat - - tridiag, vecs = jft.lanczos.lanczos_tridiag(mat, v, order=order) - eig_vals = jnp.linalg.eigvalsh(tridiag) - return jnp.log(eig_vals).sum(), vecs - - -def _metric_sample( - hamiltonian: jft.StandardHamiltonian, - primals, - key, -): - if not isinstance(hamiltonian, jft.StandardHamiltonian): - te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'" - raise TypeError(te) - - subkey_nll, subkey_prr = random.split(key, 2) - nll_smpl = jft.kl.sample_likelihood( - hamiltonian.likelihood, primals, key=subkey_nll - ) - prr_inv_metric_smpl = jft.random_like(key=subkey_prr, primals=primals) - # One may transform any metric sample to a sample of the inverse - # metric by simply applying the inverse metric to it - prr_smpl = prr_inv_metric_smpl - met_smpl = nll_smpl + prr_smpl - return met_smpl, prr_smpl - - -def geomap( - hamiltonian: jft.StandardHamiltonian, - order: int, - key, - sample_orthonormally=True -): - from jax import flatten_util - - def geomap_energy(pos, return_aux=False): - p, unflatten = flatten_util.ravel_pytree(pos) - - def mat(x): - # Hack to stomp arbitrary objects into a 1D array - o, _ = flatten_util.ravel_pytree( - hamiltonian.metric(pos, unflatten(x)) - ) - return o - - probe, smpl = _metric_sample(hamiltonian, pos, key) - probe = flatten_util.ravel_pytree(probe)[0] - smpl = flatten_util.ravel_pytree(smpl)[0] - - logdet, vecs = lanczos_logdet(mat, probe, order) - - if not sample_orthonormally: - energy = hamiltonian(pos) - smpl_orig, smpl = None, None - else: - #smpl = random.normal(smpl_key, p.shape) - smpl_orig = unflatten(smpl.copy()) - # TODO: Pull into new lanczos method which computes orthoganlized smpls - # for vecs - ortho_smpl = vecs @ smpl - # One could add an additional `jnp.linalg.inv(vecs @ vecs.T)` in - # between the vecs to ensure proper projection - # ortho_smpl = jnp.linalg.inv(vecs @ vecs.T) @ ortho_smpl - ortho_smpl = vecs.T @ ortho_smpl - smpl -= ortho_smpl - smpl = unflatten(smpl) - - # GeoMAP requires the sample to be mirrored as to perform MAP along - # the subspace in the (near) linear regime. With samples, the - # solution is not only much less noisy in this regime but is - # actually the true posterior. - energy = 0.5 * (hamiltonian(pos + smpl) + hamiltonian(pos - smpl)) - - energy += 0.5 * logdet - if return_aux: - return energy, (smpl_orig, smpl) - return energy - - return geomap_energy - - -# %% -def hartley(p, axes=None): - from jax.numpy import fft - - tmp = fft.fftn(p, axes) - return tmp.real + tmp.imag - - -seed = 42 -key = random.PRNGKey(seed) - -dims = (1024, ) - -absdelta = 1e-4 * jnp.prod(jnp.array(dims)) - -cf = {"loglogavgslope": 2.} -loglogslope = cf["loglogavgslope"] -power_spectrum = lambda k: 1. / (k**loglogslope + 1.) - -modes = jnp.arange((dims[0] / 2) + 1., dtype=float) -harmonic_power = power_spectrum(modes) -# Every mode appears exactly two times, first ascending then descending -# Save a little on the computational side by mirroring the ascending part -harmonic_power = jnp.concatenate((harmonic_power, harmonic_power[-2:0:-1])) - -# Specify the model -correlated_field = jft.Model( - lambda x: hartley(harmonic_power * x), domain=jft.ShapeWithDtype(dims) -) -signal_response = lambda x: correlated_field(x) - -noise_cov = lambda x: 0.1**2 * x -noise_cov_inv = lambda x: 0.1**-2 * x - -# Create synthetic data -key, subkey = random.split(key) -pos_truth = jft.random_like(subkey, correlated_field.domain) -signal_response_truth = signal_response(pos_truth) -key, subkey = random.split(key) -noise_truth = jnp.sqrt(noise_cov(jnp.ones(dims)) - ) * random.normal(shape=dims, key=key) -data = signal_response_truth + noise_truth - -nll = jft.Gaussian(data, noise_cov_inv) @ signal_response -ham = jft.StandardHamiltonian(likelihood=nll).jit() - -plt.plot(jnp.array([signal_response_truth, data]).T, label=("truth", "data")) -plt.legend() -plt.show() - -# %% - -key, subkey, subkey_geomap = random.split(key, 3) -pos_init = jft.random_like(subkey, correlated_field.domain) -pos = 1e-2 * pos_init.copy() - -# %% -print("!!! HAM", ham(pos)) -print("!!! metric", ham.metric(pos, pos) @ pos) -# This is 50 times slower in compile time than ham.metric -geomap_order = 40 -geomap_energy = geomap( - ham, geomap_order, subkey_geomap, sample_orthonormally=True -) - -geomap_energy = jax.jit(geomap_energy, static_argnames=("return_aux", )) -print("!!! geomap_energy", geomap_energy(pos)) - -# %% -pos = 1e-2 * pos_init.copy() - -opt_state_geomap = jft.minimize( - geomap_energy, - pos, - method="newton-cg", - options={ - "name": "N", - "maxiter": 30, - "cg_kwargs": { - "name": None - }, - } -) - -# %% -_, (prr_smpl, ortho_smpl) = geomap_energy(opt_state_geomap.x, return_aux=True) - -plt.plot(prr_smpl, label="prior sample", alpha=0.7) -plt.plot(ortho_smpl, label="ortho sample", alpha=0.7) -plt.plot(jnp.abs(prr_smpl - ortho_smpl), label="abs diff", alpha=0.3) -plt.legend() -plt.show() - -# %% -smpls_by_order = [] -for i in range(1, geomap_order): - _, (_, s) = geomap(ham, i, subkey_geomap, sample_orthonormally=True)( - opt_state_geomap.x, return_aux=True - ) - smpls_by_order += [s] - -smpls_by_order = jnp.array(smpls_by_order) -# %% -fig, axs = plt.subplots(2, 1, sharex=True) -d = jnp.diff(smpls_by_order, axis=0) -axs.flat[0].plot( - smpls_by_order.T, label=jnp.arange(1, geomap_order), alpha=0.3, marker="." -) -axs.flat[0].axhline(0., color="red") -axs.flat[0].legend() -axs.flat[1].plot( - d.T, label=jnp.arange(1, geomap_order - 1), alpha=0.3, marker="." -) -axs.flat[1].axhline(0., color="red") -axs.flat[1].legend() -plt.show() - -# %% -plt.plot( - jnp.array( - [ - signal_response_truth, - data, - signal_response(opt_state_geomap.x), - signal_response(opt_state_geomap.x + ortho_smpl), - ] - ).T, - label=("truth", "data", "rec", "rec + smpl") -) -plt.legend() -plt.show() - -# %% -n_samples = 1 -n_newton_iterations = 10 -n_mgvi_iterations = 6 - -ham_vg = jit(jft.mean_value_and_grad(ham)) -ham_metric = jit(jft.mean_metric(ham.metric)) -MetricKL = jit( - partial(jft.MetricKL, ham), - static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") -) - -# %% -pos = 1e-2 * pos_init.copy() - -# Minimize the potential -for i in range(n_mgvi_iterations): - print(f"MGVI Iteration {i}", file=sys.stderr) - print("Sampling...", file=sys.stderr) - key, subkey = random.split(key, 2) - samples = MetricKL( - pos, - n_samples=n_samples, - key=subkey, - mirror_samples=False, - linear_sampling_kwargs={ - "absdelta": absdelta / 10., - "maxiter": geomap_order - }, - # linear_sampling_name="S", - ) - - print("Minimizing...", file=sys.stderr) - opt_state_mgvi = jft.minimize( - None, - pos, - method="newton-cg", - options={ - "fun_and_grad": partial(ham_vg, primals_samples=samples), - "hessp": partial(ham_metric, primals_samples=samples), - "absdelta": absdelta, - "maxiter": n_newton_iterations - } - ) - pos = opt_state_mgvi.x - msg = f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}" - print(msg, file=sys.stderr) - -# %% -plt.plot( - jnp.array( - [ - signal_response_truth, - data, - signal_response(opt_state_geomap.x), - signal_response(opt_state_mgvi.x), - *samples.at(opt_state_mgvi.x).apply(signal_response), - ] - ).T, - label=( - "truth", - "data", - "rec geomap", - "rec mgvi", - ) + ("smpls", ) * len(samples) -) -plt.legend() -plt.show()