diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index bdec5dc6bbc22f289909ca5695b73c46ddbaace8..60c223c9c14597443b2b430c750752ff1b6a8e9c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -142,18 +142,26 @@ run_getting_started_mf: - 'getting_started_mf_results' - '*.png' +run_getting_started_nifty2jax: + stage: demo_runs + script: + - python3 demos/getting_started_6_nifty2jax.py + artifacts: + paths: + - '*.png' + run_getting_density: stage: demo_runs script: - - python3 demos/getting_started_density.py + - python3 demos/more/density_estimation.py artifacts: paths: - '*.png' -run_getting_started_model_comparison: +run_model_comparison: stage: demo_runs script: - - python3 demos/getting_started_model_comparison.py + - python3 demos/more/model_comparison.py artifacts: paths: - '*.png' @@ -161,7 +169,7 @@ run_getting_started_model_comparison: run_bernoulli: stage: demo_runs script: - - python3 demos/bernoulli_demo.py + - python3 demos/more/bernoulli_map.py artifacts: paths: - '*.png' @@ -169,7 +177,7 @@ run_bernoulli: run_curve_fitting: stage: demo_runs script: - - python3 demos/polynomial_fit.py + - python3 demos/more/polynomial_fit.py artifacts: paths: - '*.png' @@ -177,9 +185,65 @@ run_curve_fitting: run_visual_vi: stage: demo_runs script: - - python3 demos/variational_inference_visualized.py + - python3 demos/more/variational_inference_visualized.py run_meanfield: stage: demo_runs script: - - python3 demos/parametric_variational_inference.py + - python3 demos/more/parametric_variational_inference.py + +run_demo_categorical_L1: + stage: demo_runs + script: + - python3 demos/re/categorical_L1.py + artifacts: + paths: + - '*.png' + +run_demo_cf_w_known_spectrum: + stage: demo_runs + script: + - python3 demos/re/correlated_field_w_known_spectrum.py + artifacts: + paths: + - '*.png' + +run_demo_cf_w_unknown_spectrum: + stage: demo_runs + script: + - python3 demos/re/correlated_field_w_unknown_spectrum.py + artifacts: + paths: + - '*.png' + +run_demo_cf_w_unknown_factorizing_spectra: + stage: demo_runs + script: + - python3 demos/re/correlated_field_w_unknown_factorizing_spectra.py + artifacts: + paths: + - '*.png' + +run_demo_nifty_to_jifty: + stage: demo_runs + script: + - python3 demos/re/nifty_to_jifty.py + artifacts: + paths: + - '*.png' + +run_demo_banana: + stage: demo_runs + script: + - python3 demos/re/banana.py + artifacts: + paths: + - '*.png' + +run_demo_banana_w_reg: + stage: demo_runs + script: + - python3 demos/re/banana_w_reg.py + artifacts: + paths: + - '*.png' diff --git a/demos/getting_started_6_nifty2jax.py b/demos/getting_started_6_nifty2jax.py new file mode 100644 index 0000000000000000000000000000000000000000..1a1a6b2e5795ee7c90b41dbc9bff9627daf7c49a --- /dev/null +++ b/demos/getting_started_6_nifty2jax.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +# %% [markdown] +# ## What Is This All About? 
+#
+# * A short introduction to porting code from NIFTy to JAX + NIFTy (jifty)
+# * How to get the JAX expression for a NIFTy operator
+# * How to minimize in jifty
+# * A benchmark of NIFTy vs. jifty
+
+# %%
+from collections import namedtuple
+from functools import partial
+import sys
+
+from jax import jit, value_and_grad
+from jax import random
+from jax import numpy as jnp
+from jax.config import config as jax_config
+from jax.tree_util import tree_map
+import matplotlib.pyplot as plt
+import numpy as np
+
+import nifty8 as ift
+import nifty8.re as jft
+
+jax_config.update("jax_enable_x64", True)
+# jax_config.update('jax_log_compiles', True)
+
+# %%
+filename = "getting_started_nifty2jax{}.png"
+
+position_space = ift.RGSpace([512, 512])
+cfm_kwargs = {
+    'offset_mean': -2.,
+    'offset_std': (1e-5, 1e-6),
+    'fluctuations': (2., 0.2),  # Amplitude of field fluctuations
+    'loglogavgslope': (-4., 1),  # Exponent of power law power spectrum
+    # Amplitude of integrated Wiener process on top of power law power spectrum
+    'flexibility': (8e-1, 1e-1),
+    'asperity': (3e-1, 1e-3)  # Raggedness of integrated Wiener process
+}
+
+correlated_field_nft = ift.SimpleCorrelatedField(position_space, **cfm_kwargs)
+pow_spec_nft = correlated_field_nft.power_spectrum
+
+signal_nft = correlated_field_nft.exp()
+response_nft = ift.GeometryRemover(signal_nft.target)
+signal_response_nft = response_nft(signal_nft)
+
+# %% [markdown]
+# ## From NIFTy to JAX + NIFTy
+#
+# By now, we have built a beautiful and rather complicated forward model.
+# However, instead of using vanilla NumPy (i.e. plain NIFTy), we want to
+# compile the forward pass with JAX.
+
+# Note that JAX + NIFTy has no concept of domains. It still needs to know how
+# large the parameter space is, though. This can be provided either via an
+# initializer or via a pytree containing the shapes and dtypes. Thus, in
+# addition to extracting the JAX call, we also need to extract the parameter
+# space on which this call should act.
+
+# %%
+pow_spec = pow_spec_nft.jax_expr
+signal = signal_nft.jax_expr
+# Convenience method to get JAX expression and domain
+signal_response = ift.nifty2jax.convert(signal_response_nft, float)
+
+noise_cov = 0.5**2
+
+# %%
+key = random.PRNGKey(42)
+
+key, sk = random.split(key)
+synth_pos = jft.random_like(sk, signal_response)
+data = synth_signal_response = signal_response(synth_pos)
+key, sk = random.split(key)  # Do not reuse the signal key for the noise
+data += jnp.sqrt(noise_cov) * random.normal(sk, shape=data.shape)
+
+fig, axs = plt.subplots(1, 2, figsize=(8, 4))
+im = axs.flat[0].imshow(synth_signal_response)
+fig.colorbar(im, ax=axs.flat[0])
+im = axs.flat[1].imshow(data)
+fig.colorbar(im, ax=axs.flat[1])
+fig.tight_layout()
+plt.show()
+
+# %%
+lh = jft.Gaussian(data, noise_cov_inv=lambda x: x / noise_cov) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=lh).jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+key, subkey = random.split(key)
+pos = pos_init = 1e-2 * jft.random_like(subkey, signal_response)
+
+# %% [markdown]
+# Let's do a simple MGVI minimization. Note that while this might look very
+# similar to plain NIFTy, the convergence criteria and various implementation
+# details differ. Thus, timing this minimization and comparing it to NIFTy
+# would most probably yield very skewed results.
+# It is best to compare only a single value-and-gradient call in both
+# implementations for the purpose of creating a benchmark.
+
+# %%
+n_mgvi_iterations = 10
+n_samples = 2
+absdelta = 0.1
+n_newton_iterations = 15
+
+# Minimize the potential
+key, *sk = random.split(key, 1 + n_mgvi_iterations)
+for i, subkey in enumerate(sk):
+    print(f"MGVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    mg_samples = MetricKL(
+        pos,
+        n_samples=n_samples,
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_name=None,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.}
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=mg_samples),
+            "hessp": partial(ham_metric, primals_samples=mg_samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations
+        }
+    )
+    pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+# %%
+# The minimization is done now and we can have a look at the result.
+fig, axs = plt.subplots(1, 3, figsize=(12, 4))
+im = axs.flat[0].imshow(synth_signal_response)
+fig.colorbar(im, ax=axs.flat[0])
+im = axs.flat[1].imshow(data)
+fig.colorbar(im, ax=axs.flat[1])
+sr_pm = mg_samples.at(pos).mean(signal_response)
+im = axs.flat[2].imshow(sr_pm)
+fig.colorbar(im, ax=axs.flat[2])
+fig.tight_layout()
+plt.show()
+
+# %% [markdown]
+# Awesome! We have now seen how a model can be translated to JAX. By doing so,
+# we are able to use such convenient JAX transformations as `jit` and
+# `value_and_grad`. Thus, we can start using higher-order derivatives and
+# other useful JAX features like `vmap` and `pmap`. Last but certainly not
+# least, we can now also let our code run on the GPU.
+
+# %% [markdown]
+# ## Performance
+#
+# The driving force behind all of this is of course speed! So let's validate
+# that translating the model to JAX actually makes it faster.
+
+# %%
+Timed = namedtuple("Timed", ("time", "number"), rename=True)
+
+
+def timeit(stmt, setup=lambda: None, number=None):
+    import timeit
+
+    if number is None:
+        number, _ = timeit.Timer(stmt).autorange()
+
+    setup()
+    t = timeit.timeit(stmt, number=number) / number
+    return Timed(time=t, number=number)
+
+
+r = jft.random_like(random.PRNGKey(54), signal_response)
+
+r_nft = ift.makeField(signal_response_nft.domain, r.val)
+data_nft = ift.makeField(signal_response_nft.target, data)
+lh_nft = ift.GaussianEnergy(
+    data_nft,
+    inverse_covariance=ift.ScalingOperator(data_nft.domain, 1. / noise_cov)
+) @ signal_response_nft
+ham_nft = ift.StandardHamiltonian(lh_nft)
+
+_ = ham(r)  # Warm-Up
+t = timeit(lambda: ham(r).block_until_ready())
+t_nft = timeit(lambda: ham_nft(r_nft))
+
+print(f"W/ JAX :: {t}")
+print(f"W/O JAX :: {t_nft}")
+
+# %%
+# At about 2e+5 #parameters the FFT starts to dominate the computation and
+# NumPy-based NIFTy is about as fast as JAX-based NIFTy. Thus, we should not
+# have expected to gain much performance for the model at hand.
+
+# So far so good, but are we really sure that both implementations compute the
+# same thing? To validate the JAX model, let's transfer our synthetic position
+# to plain NIFTy and run the model there again.
+
+sp = ift.makeField(signal_response_nft.domain, synth_pos.val)
+np.testing.assert_allclose(
+    signal_response_nft(sp).val, signal_response(synth_pos)
+)
+
+# %% [markdown]
+# For smaller models, or models in which the FFT does not dominate, JAX-based
+# NIFTy should always have an edge over NumPy-based NIFTy. The difference in
+# performance can range from a few tens of percent at approximately 1e+5
+# #parameters to many orders of magnitude for very small models. For example,
+# with 65536 #parameters JAX-based NIFTy should be two to three times faster.
+
+# We can show this more explicitly with a proper benchmark. In the following
+# we will instantiate models of various shapes and time the JAX version
+# against the NumPy version. Instead of testing solely a single forward pass,
+# we will compare a full evaluation of the model and its gradient.
+
+
+# %%
+def get_lognormal_model(shapes, cfm_kwargs, data_key, noise_cov=0.5**2):
+    import warnings
+
+    position_space = ift.RGSpace(shapes)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            action="ignore", category=UserWarning, message="no JAX"
+        )
+        correlated_field_nft = ift.SimpleCorrelatedField(
+            position_space, **cfm_kwargs
+        )
+        signal_nft = correlated_field_nft.exp()
+        response_nft = ift.GeometryRemover(signal_nft.target)
+        signal_response_nft = response_nft(signal_nft)
+
+    signal_response = ift.nifty2jax.convert(signal_response_nft, float)
+
+    sk_signal, sk_noise = random.split(data_key)
+    synth_pos = jft.random_like(sk_signal, signal_response)
+    data = signal_response(synth_pos)
+    data += jnp.sqrt(noise_cov) * random.normal(sk_noise, shape=data.shape)
+
+    noise_cov_inv = 1. / noise_cov
+    noise_std_inv = jnp.sqrt(noise_cov_inv)
+    lh = jft.Gaussian(
+        data,
+        noise_cov_inv=lambda x: noise_cov_inv * x,
+        noise_std_inv=lambda x: noise_std_inv * x
+    ) @ signal_response
+    ham = jft.StandardHamiltonian(likelihood=lh)
+    ham_vg = value_and_grad(ham)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            action="ignore", category=UserWarning, message="no JAX"
+        )
+        data_nft = ift.makeField(signal_response_nft.target, data)
+        noise_cov_inv_nft = ift.ScalingOperator(data_nft.domain, 1. / noise_cov)
+        lh_nft = ift.GaussianEnergy(
+            data_nft, inverse_covariance=noise_cov_inv_nft
+        ) @ signal_response_nft
+        ham_nft = ift.StandardHamiltonian(lh_nft)
+
+    def ham_vg_nft(x):
+        x = x.val if isinstance(x, jft.Field) else x
+        x = ift.makeField(ham_nft.domain, x)
+        x = ift.Linearization.make_var(x)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                action="ignore", category=UserWarning, message="no JAX"
+            )
+            res = ham_nft(x)
+        one_nft = ift.Field(ift.DomainTuple.make(()), 1.)
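+        # Applying the adjoint Jacobian to the unit field pulls the scalar
+        # cotangent back through the model, i.e. this computes the gradient.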
+        bwd = res.jac.adjoint_times(one_nft)
+        return (res.val.val, bwd.val)
+
+    aux = {
+        "synthetic_position": synth_pos,
+        "hamiltonian_nft": ham_nft,
+        "hamiltonian": ham,
+        "signal_response_nft": signal_response_nft,
+        "signal_response": signal_response,
+    }
+    return ham_vg, ham_vg_nft, aux
+
+
+get_ln_mod = partial(
+    get_lognormal_model, cfm_kwargs=cfm_kwargs, data_key=key, noise_cov=0.5**2
+)
+
+dimensions_to_test = [
+    (256, ), (512, ), (1024, ), (256**2, ), (512**2, ), (128, 128), (256, 256),
+    (512, 512), (1024, 1024), (2048, 2048)
+]
+for dims in dimensions_to_test:
+    h, h_nft, aux = get_ln_mod(dims)
+    r = aux["synthetic_position"]
+    h = jit(h)
+    _ = h(r)  # Warm-Up
+
+    np.testing.assert_allclose(h(r)[0], h_nft(r)[0])
+    ift.myassert(all(tree_map(np.allclose, h(r)[1].val, h_nft(r)[1]).values()))
+    ti = timeit(lambda: h(r)[0].block_until_ready())
+    ti_n = timeit(lambda: h_nft(r))
+
+    print(
+        f"Shape {str(dims):>16s}"
+        f" :: JAX {ti.time:4.2e}"
+        f" :: NIFTy {ti_n.time:4.2e}"
+        f" ;; ({ti.number:6d}, {ti_n.number:<6d} loops respectively)"
+    )
+
+# %% [markdown]
+# | Shape        | JAX      | NIFTy    | Loops (JAX, NIFTy) |
+# |:-------------|:---------|:---------|-------------------:|
+# | (256,)       | 2.58e-05 | 6.96e-03 |          10000, 50 |
+# | (512,)       | 3.90e-05 | 7.14e-03 |          10000, 50 |
+# | (1024,)      | 6.33e-05 | 6.97e-03 |           5000, 50 |
+# | (65536,)     | 5.41e-03 | 1.42e-02 |             50, 20 |
+# | (262144,)    | 2.72e-02 | 4.41e-02 |              10, 5 |
+# | (128, 128)   | 5.07e-04 | 7.00e-03 |            500, 50 |
+# | (256, 256)   | 3.74e-03 | 1.01e-02 |            100, 20 |
+# | (512, 512)   | 1.53e-02 | 2.33e-02 |             20, 10 |
+# | (1024, 1024) | 7.80e-02 | 7.72e-02 |               5, 5 |
+# | (2048, 2048) | 3.21e-01 | 3.52e-01 |               1, 1 |
+
+# For small problems, JAX-based NIFTy is significantly faster than the
+# NumPy-based one; for really small problems it is more than 200 times faster.
+# This is because JAX significantly reduces the Python overhead: most of the
+# heavy lifting happens without going back to Python.
+
+# Notice how, above a certain threshold (here about 2e+5 #parameters), the
+# NumPy-based NIFTy and the JAX-based NIFTy start to perform similarly well
+# because the FFT becomes the sole bottleneck.
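+# %% [markdown]
+# As a quick sanity check of the FFT-bottleneck claim, one can time the
+# harmonic transform in isolation. The following is a minimal sketch and not
+# part of the measurements above; the `hartley` helper and the 2048x2048 grid
+# are merely illustrative assumptions, and the `timeit` helper from before is
+# re-used.
+
+# %%
+from jax import jit, random
+from jax import numpy as jnp
+
+
+def hartley(x):
+    # Hartley convention as used above: real plus imaginary part of the FFT
+    tmp = jnp.fft.fftn(x)
+    return tmp.real + tmp.imag
+
+
+x_ht = random.normal(random.PRNGKey(33), (2048, 2048))
+hartley_jit = jit(hartley)
+_ = hartley_jit(x_ht).block_until_ready()  # Warm-Up
+t_ht = timeit(lambda: hartley_jit(x_ht).block_until_ready())
+print(f"Hartley alone :: {t_ht}")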
diff --git a/demos/bernoulli_demo.py b/demos/more/bernoulli_map.py similarity index 100% rename from demos/bernoulli_demo.py rename to demos/more/bernoulli_map.py diff --git a/demos/misc/convolution.py b/demos/more/convolution.py similarity index 100% rename from demos/misc/convolution.py rename to demos/more/convolution.py diff --git a/demos/getting_started_density.py b/demos/more/density_estimation.py similarity index 100% rename from demos/getting_started_density.py rename to demos/more/density_estimation.py diff --git a/demos/getting_started_model_comparison.py b/demos/more/model_comparison.py similarity index 100% rename from demos/getting_started_model_comparison.py rename to demos/more/model_comparison.py diff --git a/demos/parametric_variational_inference.py b/demos/more/parametric_variational_inference.py similarity index 100% rename from demos/parametric_variational_inference.py rename to demos/more/parametric_variational_inference.py diff --git a/demos/polynomial_fit.py b/demos/more/polynomial_fit.py similarity index 100% rename from demos/polynomial_fit.py rename to demos/more/polynomial_fit.py diff --git a/demos/variational_inference_visualized.py b/demos/more/variational_inference_visualized.py similarity index 100% rename from demos/variational_inference_visualized.py rename to demos/more/variational_inference_visualized.py diff --git a/demos/re/banana.py b/demos/re/banana.py new file mode 100644 index 0000000000000000000000000000000000000000..c496d29ef001e2a3b420e906b66d4e3089c07eee --- /dev/null +++ b/demos/re/banana.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys + +from jax import numpy as jnp +from jax import lax, random +from jax import jit +from jax.config import config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +config.update("jax_enable_x64", True) + +seed = 42 +key = random.PRNGKey(seed) + + +# %% +def cartesian_product(arrays, out=None): + import numpy as np + + # Generalized N-dimensional products + arrays = [np.asarray(x) for x in arrays] + la = len(arrays) + dtype = np.find_common_type([a.dtype for a in arrays], []) + if out is None: + out = np.empty([len(a) for a in arrays] + [la], dtype=dtype) + for i, a in enumerate(np.ix_(*arrays)): + out[..., i] = a + return out.reshape(-1, la) + + +def banana_helper_phi_b(b, x): + return jnp.array([x[0], x[1] + b * x[0]**2 - 100 * b]) + + +def sample_nonstandard_hamiltonian( + likelihood, primals, key, cg=jft.static_cg, cg_name=None, cg_kwargs=None +): + if not isinstance(likelihood, jft.Likelihood): + te = f"`likelihood` of invalid type; got '{type(likelihood)}'" + raise TypeError(te) + from jax.tree_util import Partial + + cg_kwargs = cg_kwargs if cg_kwargs is not None else {} + cg_kwargs = {"name": cg_name, **cg_kwargs} + + white_sample = jft.random_like( + key, likelihood.left_sqrt_metric_tangents_shape + ) + met_smpl = likelihood.left_sqrt_metric(primals, white_sample) + inv_metric_at_p = partial( + cg, Partial(likelihood.metric, primals), **cg_kwargs + ) + signal_smpl = inv_metric_at_p(met_smpl)[0] + return signal_smpl + + +def NonStandardMetricKL( + likelihood, + primals, + n_samples, + key, + mirror_samples: bool = True, + linear_sampling_cg=jft.static_cg, + linear_sampling_name=None, + linear_sampling_kwargs=None, +): + from jax.tree_util import Partial + + if not isinstance(likelihood, jft.Likelihood): + te = f"`likelihood` of invalid type; got 
'{type(likelihood)}'" + raise TypeError(te) + + draw = Partial( + sample_nonstandard_hamiltonian, + likelihood=likelihood, + primals=primals, + cg=linear_sampling_cg, + cg_name=linear_sampling_name, + cg_kwargs=linear_sampling_kwargs, + ) + subkeys = random.split(key, n_samples) + samples_stack = lax.map(lambda k: draw(key=k), subkeys) + + return jft.kl.SampleIter( + mean=primals, + samples=jft.unstack(samples_stack), + linearly_mirror_samples=mirror_samples + ) + + +# %% +b = 0.1 + +signal_response = partial(banana_helper_phi_b, b) +nll = jft.Gaussian( + jnp.zeros(2), lambda x: x / jnp.array([100., 1.]) +) @ signal_response + +ham = nll +ham = ham.jit() +ham_vg = jit(jft.mean_value_and_grad(ham)) +ham_metric = jit(jft.mean_metric(ham.metric)) + +# %% +n_mgvi_iterations = 30 +n_samples = [1] * (n_mgvi_iterations - 10) + [2] * 5 + [3, 3, 10, 10, 100] +n_newton_iterations = [7] * (n_mgvi_iterations - 10) + [10] * 6 + 4 * [25] +absdelta = 1e-12 + +initial_position = jnp.array([1., 1.]) +mkl_pos = 1e-2 * initial_position + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + samples = NonStandardMetricKL( + ham, + mkl_pos, + n_samples[i], + key=subkey, + mirror_samples=True, + linear_sampling_kwargs={"miniter": 0}, + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + x0=mkl_pos, + method="newton-cg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=samples), + "hessp": partial(ham_metric, primals_samples=samples), + "energy_reduction_factor": None, + "absdelta": absdelta, + "maxiter": n_newton_iterations[i], + "cg_kwargs": { + "miniter": 0 + }, + "name": "N", + } + ) + mkl_pos = opt_state.x + msg = f"Post MGVI Iteration {i}: Energy {samples.at(mkl_pos).mean(ham):2.4e}" + print(msg, file=sys.stderr) + +# %% +b_space_smpls = jnp.array(tuple(samples.at(mkl_pos))) + +n_pix_sqrt = 1000 +x = jnp.linspace(-10.0, 10.0, n_pix_sqrt, endpoint=True) +y = jnp.linspace(2.0, 17.0, n_pix_sqrt, endpoint=True) +X, Y = jnp.meshgrid(x, y) +XY = jnp.array([X, Y]).T +xy = XY.reshape((XY.shape[0] * XY.shape[1], 2)) +es = jnp.exp(-lax.map(ham, xy)).reshape(XY.shape[:2]).T + +fig, ax = plt.subplots() +contour = ax.contour(X, Y, es) +ax.clabel(contour, inline=True, fontsize=10) +ax.scatter(*b_space_smpls.T) +ax.plot(*mkl_pos, "rx") +fig.tight_layout() +fig.savefig("banana_mgvi_wo_regularization.png", dpi=400) +plt.close() diff --git a/demos/re/banana_w_reg.py b/demos/re/banana_w_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..f93f7e9c203a7396fa267519d04f77e0877cea13 --- /dev/null +++ b/demos/re/banana_w_reg.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +# %% +from functools import partial +import sys + +from jax import numpy as jnp +from jax import lax, random +from jax import jit, value_and_grad +from jax.config import config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +config.update("jax_enable_x64", True) + +seed = 42 +key = random.PRNGKey(seed) + + +# %% +def cartesian_product(arrays, out=None): + import numpy as np + + # Generalized N-dimensional products + arrays = [np.asarray(x) for x in arrays] + la = len(arrays) + dtype = np.find_common_type([a.dtype for a in arrays], []) + if out is None: + out = np.empty([len(a) for a in arrays] + [la], dtype=dtype) + for i, a in 
enumerate(np.ix_(*arrays)): + out[..., i] = a + return out.reshape(-1, la) + + +def banana_helper_phi_b(b, x): + return jnp.array([x[0], x[1] + b * x[0]**2 - 100 * b]) + + +# %% +b = 0.1 + +SCALE = 10. + +signal_response = lambda s: banana_helper_phi_b(b, SCALE * s) +nll = jft.Gaussian( + jnp.zeros(2), lambda x: x / jnp.array([100., 1.]) +) @ signal_response +nll = nll.jit() +nll_vg = jit(value_and_grad(nll)) + +ham = jft.StandardHamiltonian(nll) +ham = ham.jit() +ham_vg = jit(jft.mean_value_and_grad(ham)) +ham_metric = jit(jft.mean_metric(ham.metric)) + +MetricKL = jit( + partial(jft.MetricKL, ham), + static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") +) +GeoMetricKL = partial(jft.GeoMetricKL, ham) + +# # %% +# # TODO: Stabilize inversion +# gkl_position = jnp.array([1.15995025, -0.35110244]) +# special_key = jnp.array([3269562362, 460782344], dtype=jnp.uint32) +# err = jft.geometrically_sample_standard_hamiltonian( +# key=special_key, +# hamiltonian=ham, +# primals=gkl_position, +# mirror_linear_sample=False, +# linear_sampling_name="SCG", +# linear_sampling_kwargs={"miniter": -1}, +# non_linear_sampling_name="S", +# non_linear_sampling_kwargs={ +# "cg_kwargs": { +# "miniter": -1 +# }, +# "maxiter": 20, +# } +# ) + +# %% # MGVI +n_mgvi_iterations = 30 +n_samples = [1] * (n_mgvi_iterations - 2) + [2] + [100] +n_newton_iterations = [7] * (n_mgvi_iterations - 10) + [10] * 6 + 4 * [25] +absdelta = 1e-10 + +initial_position = jnp.array([1., 1.]) +mkl_pos = 1e-2 * initial_position + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + mg_samples = MetricKL( + mkl_pos, + n_samples=n_samples[i], + key=subkey, + mirror_samples=True, + linear_sampling_kwargs={"miniter": 0} + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + x0=mkl_pos, + method="newton-cg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=mg_samples), + "hessp": partial(ham_metric, primals_samples=mg_samples), + "energy_reduction_factor": None, + "absdelta": absdelta, + "maxiter": n_newton_iterations[i], + "cg_kwargs": { + "miniter": 0, + "name": None + }, + "name": "N" + } + ) + mkl_pos = opt_state.x + print( + ( + f"Post MGVI Iteration {i}: Energy {mg_samples.at(mkl_pos).mean(ham):2.4e}" + f"; #NaNs {jnp.isnan(mkl_pos).sum()}" + ), + file=sys.stderr + ) + +# %% # geoVI +n_geovi_iterations = 15 +n_samples = [1] * (n_geovi_iterations - 2) + [2] + [100] +n_newton_iterations = [7] * (n_geovi_iterations - 10) + [10] * 6 + [25] * 4 +absdelta = 1e-10 + +initial_position = jnp.array([1., 1.]) +gkl_pos = 1e-2 * initial_position + +for i in range(n_geovi_iterations): + print(f"GeoVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + geo_samples = GeoMetricKL( + gkl_pos, + n_samples[i], + key=subkey, + mirror_samples=True, + linear_sampling_name=None, + linear_sampling_kwargs={"miniter": 0}, + non_linear_sampling_name=None, + non_linear_sampling_kwargs={ + "cg_kwargs": { + "miniter": 0, + "absdelta": None + }, + "maxiter": 20, + }, + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + x0=gkl_pos, + method="newton-cg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=geo_samples), + "hessp": partial(ham_metric, primals_samples=geo_samples), + "energy_reduction_factor": None, + "absdelta": absdelta, + "maxiter": n_newton_iterations[i], + 
"cg_kwargs": { + "miniter": 0, + "name": None + }, + "name": "N", + } + ) + gkl_pos = opt_state.x + +# %% +absdelta = 1e-10 +opt_state = jft.minimize( + None, + x0=jnp.array([1., 1.]), + method="newton-cg", + options={ + "fun_and_grad": ham_vg, + "hessp": ham.metric, + "energy_reduction_factor": None, + "absdelta": absdelta, + "maxiter": 100, + "cg_kwargs": { + "miniter": 0, + "name": None + }, + "name": "MAP" + } +) +map_pos = opt_state.x +key, subkey = random.split(key, 2) +map_geo_samples = GeoMetricKL( + map_pos, + 100, + key=subkey, + mirror_samples=True, + linear_sampling_name=None, + linear_sampling_kwargs={"miniter": 0}, + non_linear_sampling_name=None, + non_linear_sampling_kwargs={ + "cg_kwargs": { + "miniter": 0 + }, + "maxiter": 20, + } +) + +# %% + +n_pix_sqrt = 1000 +x = jnp.linspace(-30 / SCALE, 30 / SCALE, n_pix_sqrt) +y = jnp.linspace(-15 / SCALE, 15 / SCALE, n_pix_sqrt) +X, Y = jnp.meshgrid(x, y) +XY = jnp.array([X, Y]).T +xy = XY.reshape((XY.shape[0] * XY.shape[1], 2)) +es = jnp.exp(-lax.map(ham, xy)).reshape(XY.shape[:2]).T + +fig, axs = plt.subplots(1, 3, figsize=(16, 9)) + +b_space_smpls = jnp.array(tuple(mg_samples.at(mkl_pos))) +contour = axs[0].contour(X, Y, es) +axs[0].clabel(contour, inline=True, fontsize=10) +axs[0].scatter(*b_space_smpls.T) +axs[0].plot(*mkl_pos, "rx") +axs[0].set_title("MGVI") + +b_space_smpls = jnp.array(tuple(geo_samples.at(gkl_pos))) +contour = axs[1].contour(X, Y, es) +axs[1].clabel(contour, inline=True, fontsize=10) +axs[1].scatter(*b_space_smpls.T, alpha=0.7) +axs[1].plot(*gkl_pos, "rx") +axs[1].set_title("GeoVI") + +b_space_smpls = jnp.array(tuple(map_geo_samples.at(map_pos))) +contour = axs[2].contour(X, Y, es) +axs[2].clabel(contour, inline=True, fontsize=10) +axs[2].scatter(*b_space_smpls.T, alpha=0.7) +axs[2].plot(*map_pos, "rx") +axs[2].set_title("MAP + GeoVI Samples") + +fig.tight_layout() +fig.savefig("banana_vi_w_regularization.png", dpi=400) +plt.close() diff --git a/demos/re/categorical_L1.py b/demos/re/categorical_L1.py new file mode 100644 index 0000000000000000000000000000000000000000..8dbed1c54e9e265e63b4132129a2937264c9d3f7 --- /dev/null +++ b/demos/re/categorical_L1.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys + +from jax import numpy as jnp +from jax import random +from jax import jit, value_and_grad +from jax.config import config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +config.update("jax_enable_x64", True) + + +def build_model(predictors, targets, sh, alpha=1): + my_laplace_prior = jft.interpolate()(jft.laplace_prior(alpha)) + matrix = lambda x: my_laplace_prior(x).reshape(sh) + model = lambda x: jnp.matmul(predictors, matrix(x)) + lh = jft.Categorical(targets, axis=1) + return {"lh": lh @ model, "logits": model, "matrix": matrix} + + +seed = 42 +key = random.PRNGKey(seed) + +N_data = 1024 +N_categories = 10 +N_predictors = 3 + +n_mgvi_iterations = 5 +n_samples = 5 +mirror_samples = True +n_newton_iterations = 5 + +# Create synthetic data +mock_predictors = random.normal(shape=(N_data, N_predictors), key=key) +key, subkey = random.split(key) +model = build_model( + mock_predictors, jnp.zeros((N_data, 1), dtype=jnp.int32), + (N_predictors, N_categories) +) +latent_truth = random.normal(shape=(N_predictors * N_categories, ), key=subkey) +key, subkey = random.split(key) +matrix_truth = model["matrix"](latent_truth) +logits_truth = 
model["logits"](latent_truth) + +mock_targets = random.categorical(logits=logits_truth, key=subkey) +key, subkey = random.split(key) +mock_targets = mock_targets.reshape(N_data, 1) + +model = build_model(mock_predictors, mock_targets, (N_predictors, N_categories)) +ham = jft.StandardHamiltonian(likelihood=model["lh"]).jit() + +pos_init = .1 * random.normal(shape=(N_predictors * N_categories, ), key=subkey) +key, subkey = random.split(key) +pos = pos_init.copy() + +diff_to_truth = jnp.linalg.norm(model["matrix"](pos) - matrix_truth) +print(f"Initial diff to truth {diff_to_truth}", file=sys.stderr) + + +def energy(p, samps): + return jnp.mean(jnp.array([ham(p + s) for s in samps]), axis=0) + + +@jit +def metric(p, t, samps): + results = [ham.metric(p + s, t) for s in samps] + return jnp.mean(jnp.array(results), axis=0) + + +energy_vag = jit(value_and_grad(energy)) +draw = partial(jft.kl.sample_standard_hamiltonian, hamiltonian=ham) + +# Preform MGVI loop +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + key, *subkeys = random.split(key, 1 + n_samples) + samples = [] + samples = [draw(primals=pos, key=k) for k in subkeys] + + Evag = lambda p: energy_vag(p, samples) + met = lambda p, t: metric(p, t, samples) + opt_state = jft.minimize( + None, + x0=pos, + method="newton-cg", + options={ + "fun_and_grad": Evag, + "hessp": met, + "maxiter": n_newton_iterations + } + ) + pos = opt_state.x + diff_to_truth = jnp.linalg.norm(model["matrix"](pos) - matrix_truth) + print( + ( + f"Post MGVI Iteration {i}: Energy {Evag(pos)[0]:2.4e}" + f"; diff to truth {diff_to_truth}" + ), + file=sys.stderr + ) + +posterior_samps = [s + pos for s in samples] + +matrix_samps = jnp.array([model["matrix"](s) for s in posterior_samps]) +matrix_mean = jnp.mean(matrix_samps, axis=0) +matrix_std = jnp.std(matrix_samps, axis=0) +xx = jnp.linspace(-3.5, 3.5, 2) +plt.plot(xx, xx) +plt.errorbar( + matrix_truth.reshape(-1), + matrix_mean.reshape(-1), + yerr=matrix_std.reshape(-1), + fmt='o', + color="black" +) +plt.xlabel("truth") +plt.ylabel("inferred value") +plt.savefig("matrix_fit.png", dpi=400) +plt.close() diff --git a/demos/re/correlated_field_w_known_spectrum.py b/demos/re/correlated_field_w_known_spectrum.py new file mode 100644 index 0000000000000000000000000000000000000000..952572305fa2e642bf5dcc16cd1651312d313832 --- /dev/null +++ b/demos/re/correlated_field_w_known_spectrum.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys + +from jax import numpy as jnp +from jax import random +from jax import jit +from jax.config import config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +config.update("jax_enable_x64", True) + + +@jit +def cosine_similarity(x, y): + return jnp.dot(x, y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y)) + + +def hartley(p, axes=None): + from jax.numpy import fft + + tmp = fft.fftn(p, axes) + return tmp.real + tmp.imag + + +seed = 42 +key = random.PRNGKey(seed) + +dims = (1024, ) + +n_mgvi_iterations = 3 +n_samples = 4 +n_newton_iterations = 5 +absdelta = 1e-4 * jnp.prod(jnp.array(dims)) + +cf = {"loglogavgslope": 2.} +loglogslope = cf["loglogavgslope"] +power_spectrum = lambda k: 1. / (k**loglogslope + 1.) 
+ +modes = jnp.arange((dims[0] / 2) + 1., dtype=float) +harmonic_power = power_spectrum(modes) +# Every mode appears exactly two times, first ascending then descending +# Save a little on the computational side by mirroring the ascending part +harmonic_power = jnp.concatenate((harmonic_power, harmonic_power[-2:0:-1])) + +# Specify the model +correlated_field = lambda x: hartley(harmonic_power * x.val) +signal_response = lambda x: jnp.exp(1. + correlated_field(x)) +noise_cov = lambda x: 0.1**2 * x +noise_cov_inv = lambda x: 0.1**-2 * x + +# Create synthetic data +key, subkey = random.split(key) +pos_truth = jft.Field(random.normal(shape=dims, key=key)) +signal_response_truth = signal_response(pos_truth) +key, subkey = random.split(key) +noise_truth = jnp.sqrt(noise_cov(jnp.ones(dims)) + ) * random.normal(shape=dims, key=key) +data = signal_response_truth + noise_truth + +nll = jft.Gaussian(data, noise_cov_inv) @ signal_response +ham = jft.StandardHamiltonian(likelihood=nll).jit() + +key, subkey = random.split(key) +pos_init = random.normal(shape=dims, key=subkey) +pos = 1e-2 * jft.Field(pos_init) + +ham_vg = jit(jft.mean_value_and_grad(ham)) +ham_metric = jit(jft.mean_metric(ham.metric)) +MetricKL = jit( + partial(jft.MetricKL, ham), + static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") +) + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + samples = MetricKL( + pos, + n_samples=n_samples, + key=subkey, + mirror_samples=True, + linear_sampling_kwargs={"absdelta": absdelta / 10.} + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + x0=pos, + method="trust-ncg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=samples), + "hessp": partial(ham_metric, primals_samples=samples), + "initial_trust_radius": 1e+1, + "max_trust_radius": 1e+4, + "absdelta": absdelta, + "maxiter": n_newton_iterations, + "name": "N", + "subproblem_kwargs": { + "miniter": 6, + } + } + ) + # opt_state = jft.minimize( + # None, + # x0=pos, + # method="newton-cg", + # options={ + # "fun_and_grad": partial(ham_vg, primals_samples=samples), + # "hessp": partial(ham_metric, primals_samples=samples), + # "absdelta": absdelta, + # "maxiter": n_newton_iterations + # } + # ) + pos = opt_state.x + print( + ( + f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}" + f"; Cos-Sim {cosine_similarity(pos.val, pos_truth.val):2.3%}" + f"; #NaNs {jnp.isnan(pos.val).sum()}" + ), + file=sys.stderr + ) + +post_sr_mean = jft.mean(tuple(signal_response(s) for s in samples.at(pos))) +fig, ax = plt.subplots() +ax.plot(signal_response_truth, alpha=0.7, label="Signal") +ax.plot(noise_truth, alpha=0.7, label="Noise") +ax.plot(data, alpha=0.7, label="Data") +ax.plot(post_sr_mean, alpha=0.7, label="Reconstruction") +ax.legend() +fig.tight_layout() +fig.savefig("cf_w_known_spectrum.png", dpi=400) +plt.close() diff --git a/demos/re/correlated_field_w_unknown_factorizing_spectra.py b/demos/re/correlated_field_w_unknown_factorizing_spectra.py new file mode 100644 index 0000000000000000000000000000000000000000..76e246479bf954aa5703ad65073ed857585dff87 --- /dev/null +++ b/demos/re/correlated_field_w_unknown_factorizing_spectra.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys + +from jax import numpy as 
jnp +from jax import random +from jax import jit +from jax.config import config as jax_config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +jax_config.update("jax_enable_x64", True) + +seed = 42 +key = random.PRNGKey(seed) + +dims_ax1 = (128, ) +dims_ax2 = (256, ) + +n_mgvi_iterations = 3 +n_samples = 4 +n_newton_iterations = 10 +absdelta = 1e-4 * jnp.prod(jnp.array(dims_ax1 + dims_ax2)) + +cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)} +cf_fl = { + "fluctuations": (1e-1, 5e-3), + "loglogavgslope": (-1., 1e-2), + "flexibility": (1e+0, 5e-1), + "asperity": (5e-1, 1e-1), + "harmonic_domain_type": "Fourier" +} +cfm = jft.CorrelatedFieldMaker("cf") +cfm.set_amplitude_total_offset(**cf_zm) +d = 1. / dims_ax1[0] +cfm.add_fluctuations(dims_ax1, distances=d, **cf_fl, prefix="ax1") +d = 1. / dims_ax2[0] +cfm.add_fluctuations(dims_ax2, distances=d, **cf_fl, prefix="ax2") +correlated_field, ptree = cfm.finalize() + +signal_response = lambda x: correlated_field(x) +noise_cov = lambda x: 0.1**2 * x +noise_cov_inv = lambda x: 0.1**-2 * x + +# Create synthetic data +key, subkey = random.split(key) +pos_truth = jft.random_like(subkey, ptree) +signal_response_truth = signal_response(pos_truth) +key, subkey = random.split(key) +noise_truth = jnp.sqrt( + noise_cov(jnp.ones(signal_response_truth.shape)) +) * random.normal(shape=signal_response_truth.shape, key=key) +data = signal_response_truth + noise_truth + +nll = jft.Gaussian(data, noise_cov_inv) @ signal_response +ham = jft.StandardHamiltonian(likelihood=nll).jit() + +ham_vg = jit(jft.mean_value_and_grad(ham)) +ham_metric = jit(jft.mean_metric(ham.metric)) +MetricKL = jit( + partial(jft.MetricKL, ham), + static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") +) + +key, subkey = random.split(key) +pos_init = jft.random_like(subkey, ptree) +pos = 1e-2 * jft.Field(pos_init) + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + samples = MetricKL( + pos, + n_samples=n_samples, + key=subkey, + mirror_samples=True, + linear_sampling_kwargs={"absdelta": absdelta / 10.} + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + x0=pos, + method="newton-cg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=samples), + "hessp": partial(ham_metric, primals_samples=samples), + "absdelta": absdelta, + "maxiter": n_newton_iterations + } + ) + pos = opt_state.x + msg = f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}" + print(msg, file=sys.stderr) + +namps = cfm.get_normalized_amplitudes() +post_sr_mean = jft.mean(tuple(signal_response(s) for s in samples.at(pos))) +post_namps1_mean = jft.mean(tuple(namps[0](s)[1:] for s in samples.at(pos))) +post_namps2_mean = jft.mean(tuple(namps[1](s)[1:] for s in samples.at(pos))) +to_plot = [ + ("Signal", signal_response_truth, "im"), + ("Noise", noise_truth, "im"), + ("Data", data, "im"), + ("Reconstruction", post_sr_mean, "im"), + ("Ax1", (namps[0](pos_truth)[1:], post_namps1_mean), "loglog"), + ("Ax2", (namps[1](pos_truth)[1:], post_namps2_mean), "loglog"), +] +fig, axs = plt.subplots(2, 3, figsize=(16, 9)) +for ax, (title, field, tp) in zip(axs.flat, to_plot): + ax.set_title(title) + if tp == "im": + im = ax.imshow(field, cmap="inferno") + plt.colorbar(im, ax=ax, orientation="horizontal") + else: + ax_plot = ax.loglog if tp == "loglog" else ax.plot + field = field if isinstance(field, (tuple, 
list)) else (field, ) + for f in field: + ax_plot(f, alpha=0.7) +fig.tight_layout() +fig.savefig("cf_w_unknown_factorizing_spectra.png", dpi=400) +plt.close() diff --git a/demos/re/correlated_field_w_unknown_spectrum.py b/demos/re/correlated_field_w_unknown_spectrum.py new file mode 100644 index 0000000000000000000000000000000000000000..b5254b753e918b60a4f4fe71ef292d71eba6693c --- /dev/null +++ b/demos/re/correlated_field_w_unknown_spectrum.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys + +from jax import numpy as jnp +from jax import random +from jax import jit +from jax.config import config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +config.update("jax_enable_x64", True) + +seed = 42 +key = random.PRNGKey(seed) + +dims = (256, 256) + +n_mgvi_iterations = 3 +n_samples = 4 +n_newton_iterations = 10 +absdelta = 1e-4 * jnp.prod(jnp.array(dims)) + +cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)} +cf_fl = { + "fluctuations": (1e-1, 5e-3), + "loglogavgslope": (-1., 1e-2), + "flexibility": (1e+0, 5e-1), + "asperity": (5e-1, 5e-2), + "harmonic_domain_type": "Fourier" +} +cfm = jft.CorrelatedFieldMaker("cf") +cfm.set_amplitude_total_offset(**cf_zm) +cfm.add_fluctuations(dims, distances=1. / dims[0], **cf_fl, prefix="ax1") +correlated_field, ptree = cfm.finalize() + +signal_response = lambda x: jnp.exp(correlated_field(x)) +noise_cov = lambda x: 0.1**2 * x +noise_cov_inv = lambda x: 0.1**-2 * x + +# Create synthetic data +key, subkey = random.split(key) +pos_truth = jft.random_like(subkey, ptree) +signal_response_truth = signal_response(pos_truth) +key, subkey = random.split(key) +noise_truth = jnp.sqrt(noise_cov(jnp.ones(dims)) + ) * random.normal(shape=dims, key=key) +data = signal_response_truth + noise_truth + +nll = jft.Gaussian(data, noise_cov_inv) @ signal_response +ham = jft.StandardHamiltonian(likelihood=nll).jit() + +ham_vg = jit(jft.mean_value_and_grad(ham)) +ham_metric = jit(jft.mean_metric(ham.metric)) +MetricKL = jit( + partial(jft.MetricKL, ham), + static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") +) + +key, subkey = random.split(key) +pos_init = jft.random_like(subkey, ptree) +pos = 1e-2 * jft.Field(pos_init.copy()) + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + samples = jft.MetricKL( + ham, + pos, + n_samples=n_samples, + key=subkey, + mirror_samples=True, + linear_sampling_kwargs={"absdelta": absdelta / 10.} + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + pos, + method="newton-cg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=samples), + "hessp": partial(ham_metric, primals_samples=samples), + "absdelta": absdelta, + "maxiter": n_newton_iterations + } + ) + pos = opt_state.x + msg = f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}" + print(msg, file=sys.stderr) + +namps = cfm.get_normalized_amplitudes() +post_sr_mean = jft.mean(tuple(signal_response(s) for s in samples.at(pos))) +post_a_mean = jft.mean(tuple(cfm.amplitude(s)[1:] for s in samples.at(pos))) +to_plot = [ + ("Signal", signal_response_truth, "im"), + ("Noise", noise_truth, "im"), + ("Data", data, "im"), + ("Reconstruction", post_sr_mean, "im"), + ("Ax1", (cfm.amplitude(pos_truth)[1:], post_a_mean), "loglog"), +] 
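+# Note: the `[1:]` slices above presumably drop the amplitude's leading
+# (zero-mode) entry, which cannot be drawn at index zero on a log-log axis.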
+fig, axs = plt.subplots(2, 3, figsize=(16, 9))
+for ax, (title, field, tp) in zip(axs.flat, to_plot):
+    ax.set_title(title)
+    if tp == "im":
+        im = ax.imshow(field, cmap="inferno")
+        plt.colorbar(im, ax=ax, orientation="horizontal")
+    else:
+        ax_plot = ax.loglog if tp == "loglog" else ax.plot
+        field = field if isinstance(field, (tuple, list)) else (field, )
+        for f in field:
+            ax_plot(f, alpha=0.7)
+fig.tight_layout()
+fig.savefig("cf_w_unknown_spectrum.png", dpi=400)
+plt.close()
diff --git a/demos/re/graph_refine.py b/demos/re/graph_refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..edc5e701cc134ee47a6058421d51ce7ca0b7f606
--- /dev/null
+++ b/demos/re/graph_refine.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+import jax
+from jax import numpy as jnp
+from jax import random
+import matplotlib.pyplot as plt
+import numpy as np
+
+import nifty8.re as jft
+
+
+def get_kernel(layer_weights, depth, n_samples=100):
+    xi = {"offset": 0., "layer_weights": layer_weights}
+    kernel = np.zeros(2**depth)
+    for _ in range(n_samples):
+        xi["excitations"] = jnp.array(rng.normal(size=(2**depth, )))
+        r = fwd(xi)
+        for i in range(r.size):
+            kernel[i] += np.mean(r * np.roll(r, i))
+    kernel /= n_samples  # average over the Monte Carlo draws
+    return kernel
+
+
+def fwd(xi):
+    offset = xi["offset"]
+    excitations = xi["excitations"]
+    layer_wgt = xi["layer_weights"]
+
+    kernel = jnp.array([1., 2., 1.])
+    kernel /= kernel.sum()
+    layers = [excitations]
+    while layers[-1].size > 1:
+        lvl = layers[-1]
+        if layers[-1].size > 2:
+            lvl = jnp.convolve(lvl, kernel, mode="same")
+        layers += [0.5 * lvl.reshape(-1, 2).sum(axis=1)]
+    if len(layers) != len(layer_wgt):
+        raise ValueError()
+
+    field = offset
+    for d, (wgt, lvl) in enumerate(zip(layer_wgt, layers)):
+        field += wgt * jnp.repeat(lvl, 2**d)
+
+    return field
+
+
+# %%
+rng = np.random.default_rng(42)
+depth = 8
+
+for _ in range(10):
+    layer_weights = jnp.array(rng.normal(size=(depth + 1, )))
+    layer_weights = jnp.exp(0.1 * layer_weights)  # lognormal
+    kernel = get_kernel(layer_weights, depth, n_samples=30)
+    plt.plot(kernel)
+plt.show()
+
+# %%
+spec = np.fft.fft(kernel)
+plt.plot(spec)
+plt.yscale("log")
+plt.xscale("log")
+plt.show()
+
+# %%
+
+rng = np.random.default_rng(42)
+depth = 12
+xi = {
+    "offset": 0.,
+    "excitations": jnp.array(rng.normal(size=(2**depth, ))),
+    # "layer_weights": jnp.exp(+jnp.array(rng.normal(size=(depth + 1, )))),
+    "layer_weights": jnp.exp(0.1 * jnp.arange(depth + 1, dtype=float)),
+    # "layer_weights": jnp.array([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,]),
+}
+
+plt.plot(xi["excitations"], label="excitations", alpha=0.6)
+plt.plot(fwd(xi), label="Forward Model", alpha=0.6)
+plt.legend()
+plt.show()
+
+# %%
+cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)}
+cf_fl = {
+    "fluctuations": (1e-1, 5e-3),
+    "loglogavgslope": (-1.5, 1e-2),
+    "flexibility": (5e-1, 1e-1),
+    "asperity": (5e-1, 5e-2),
+    "harmonic_domain_type": "Fourier"
+}
+dims = jnp.array([2**depth])
+
+cfm = jft.CorrelatedFieldMaker("cf")
+cfm.set_amplitude_total_offset(**cf_zm)
+cfm.add_fluctuations(dims, distances=1. 
/ dims.shape[0], **cf_fl, prefix="ax1") +correlated_field, ptree = cfm.finalize() +key = random.PRNGKey(42) + +pos_truth = jft.random_like(key, ptree) +plt.plot(correlated_field(pos_truth)) +plt.show() + +# %% +d = correlated_field(pos_truth) +lh = jax.jit( + lambda x: ((d - fwd(x))**2).sum() + + sum([(el**2).sum() for el in jax.tree_util.tree_leaves(x)]) +) +print(lh(xi)) + +# %% +opt_state = jft.minimize( + lh, + jft.Field(xi), + method="newton-cg", + options={ + "name": "N", + "absdelta": 0.1, + "maxiter": 30 + } +) + +# %% +plt.plot(correlated_field(pos_truth), label="truth") +plt.plot(fwd(opt_state.x), label="reconstruction") +plt.legend() +plt.show() + +# %% +pos_rec = opt_state.x.val.copy() +pos_rec["layer_weights"] = pos_rec["layer_weights"].at[:-8].set(0.) + +pos_truth = jft.random_like(key, ptree) +plt.plot(correlated_field(pos_truth), alpha=0.7, label="truth") +plt.plot(fwd(opt_state.x), alpha=0.7, label="reconstruction") +plt.plot(fwd(pos_rec), alpha=0.7, label="reconstruction coarse") +plt.legend() +plt.show() diff --git a/demos/re/hmc_multimodality.py b/demos/re/hmc_multimodality.py new file mode 100644 index 0000000000000000000000000000000000000000..7ee563fb545b40e912d5673a13787c2e2c61771d --- /dev/null +++ b/demos/re/hmc_multimodality.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +# %% +from functools import partial + +import jax.numpy as jnp +import matplotlib.pyplot as plt + +import nifty8.re as jft + + +def loggaussian(x, mu, sigma): + return -0.5 * (x - mu)**2 / sigma + + +def sum_of_gaussians(x, separation, sigma1, sigma2): + return -jnp.logaddexp( + loggaussian(x, 0, sigma1), loggaussian(x, separation, sigma2) + ) + + +ham = partial(sum_of_gaussians, separation=10., sigma1=1., sigma2=1.) + +N = 100000 +SEED = 43 +EPS = 0.3 + +subplots = (2, 2) +fig_width_pt = 426 # pt (a4paper, and such) +# fig_width_pt = 360 # pt +inches_per_pt = 1 / 72.27 +fig_width_in = 0.9 * fig_width_pt * inches_per_pt +fig_height_in = fig_width_in * 0.618 * (subplots[0] / subplots[1]) +fig_dims = (fig_width_in, fig_height_in * 1.5) + +fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots( + subplots[0], + subplots[1], + sharex='col', + figsize=fig_dims, + gridspec_kw={'width_ratios': [1, 2]} +) + +# %% +nuts_sampler = jft.NUTSChain( + potential_energy=ham, + inverse_mass_matrix=5., + position_proto=jnp.array(0.), + step_size=EPS, + max_tree_depth=15, + max_energy_difference=1000., +) + +chain, _ = nuts_sampler.generate_n_samples( + SEED, jnp.array(3.), num_samples=N, save_intermediates=True +) +print(f"small mass matrix acceptance: {chain.acceptance}") + +ax1.hist(chain.samples, bins=30, density=True) +ax2.plot(chain.samples, linewidth=0.5) + +ax1.set_title(rf'$m={1. / nuts_sampler.inverse_mass_matrix:1.2f}$') +ax2.set_title(rf'$m={1. / nuts_sampler.inverse_mass_matrix:1.2f}$') + +# %% +nuts_sampler = jft.NUTSChain( + potential_energy=ham, + inverse_mass_matrix=50., + position_proto=jnp.array(0.), + step_size=EPS, + max_tree_depth=15, + max_energy_difference=1000., +) + +chain, _ = nuts_sampler.generate_n_samples( + SEED, jnp.array(3.), num_samples=N, save_intermediates=True +) +print(f"large mass matrix acceptance: {chain.acceptance}") + +ax3.hist(chain.samples, bins=30, density=True) +ax4.plot(chain.samples, linewidth=0.5) + +ax3.set_title(rf'$m={1. / nuts_sampler.inverse_mass_matrix:1.2f}$') +ax4.set_title(rf'$m={1. 
/ nuts_sampler.inverse_mass_matrix:1.2f}$') + +# %% +xs = jnp.linspace(-10, 20, num=500) +Z = jnp.trapz(jnp.exp(-ham(xs)), xs) +ax1.plot(xs, jnp.exp(-ham(xs)) / Z, linewidth=0.5, c='r') +ax3.plot(xs, jnp.exp(-ham(xs)) / Z, linewidth=0.5, c='r') + +ax1.set_ylabel('frequency') +ax2.set_ylabel('position') +ax3.set_xlabel('position') +ax3.set_ylabel('frequency') +ax4.set_xlabel('time') +ax4.set_ylabel('position') + +#fig.suptitle("sum of two Gaussians, with different choices of mass matrix") + +fig.tight_layout() +fig.savefig("multimodal.pdf", bbox_inches='tight') +print("final figure saved as multimodal.pdf") diff --git a/demos/re/hmc_nuts_trajectories.py b/demos/re/hmc_nuts_trajectories.py new file mode 100644 index 0000000000000000000000000000000000000000..340fc8ca2eb40caab298c67f8fe716aae48c114a --- /dev/null +++ b/demos/re/hmc_nuts_trajectories.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +# %% +# +# WARNING: This code does not behave deterministically. It works fine when +# executing cell by cell using vscodes notebook functionality but when running +# from the command line with either python3 or ipython3 the following happens: +# This is probably due to an issue with host_callback. +# Concretely it just stops adding points to the debug list after some random +# number of leapfrog steps. +# + +import jax.numpy as jnp +import matplotlib +import matplotlib.pyplot as plt + +import nifty8.re as jft + +# %% +jft.hmc._DEBUG_FLAG = True + +# %% +cov = jnp.array([10., 1.]) + +potential_energy = lambda q: jnp.sum(0.5 * q**2 / cov) + +initial_position = jnp.array([1., 1.]) + +sampler = jft.NUTSChain( + potential_energy=potential_energy, + inverse_mass_matrix=1., + position_proto=initial_position, + step_size=0.12, + max_tree_depth=10, +) + +# %% +jft.hmc._DEBUG_STORE = [] +jft.hmc._DEBUG_TREE_END_IDXS = [] +jft.hmc._DEBUG_SUBTREE_END_IDXS = [] + +chain, _ = sampler.generate_n_samples( + 48, initial_position, num_samples=5, save_intermediates=True +) + +plt.hist(chain.depths) +plt.show() + +# %% +debug_pos = jnp.array([qp.position for qp in jft.hmc._DEBUG_STORE]) +print(len(debug_pos)) + +# %% +prop_cycle = plt.rcParams['axes.prop_cycle'] +colors = prop_cycle.by_key()['color'] + +ax = plt.gca() +ellipse = matplotlib.patches.Ellipse( + xy=(0, 0), + width=jnp.sqrt(cov[0]), + height=jnp.sqrt(cov[1]), + edgecolor='k', + fc='None', + lw=1 +) +ax.add_patch(ellipse) + +color_idx = 0 +start_and_end_idxs = zip( + [ + 0, + ] + jft.hmc._DEBUG_SUBTREE_END_IDXS[:-1], jft.hmc._DEBUG_SUBTREE_END_IDXS +) +for start_idx, end_idx in start_and_end_idxs: + slice = debug_pos[start_idx:end_idx] + ax.plot( + slice[:, 0], + slice[:, 1], + '-o', + markersize=1, + linewidth=0.5, + color=colors[color_idx % len(colors)] + ) + if end_idx in jft.hmc._DEBUG_TREE_END_IDXS: + color_idx = (color_idx + 1) % len(colors) + +ax.scatter( + chain.samples[:, 0], + chain.samples[:, 1], + marker='x', + color='k', + label='samples' +) +ax.scatter(initial_position[0], initial_position[1], label='starting position') +ax.set_xlabel('x') +ax.set_ylabel('y') +ax.legend() + +fig_width_pt = 426 # pt (a4paper, and such) +# fig_width_pt = 360 # pt +inches_per_pt = 1 / 72.27 +fig_width_in = 0.9 * fig_width_pt * inches_per_pt +fig_height_in = fig_width_in * 0.618 +fig_dims = (fig_width_in, fig_height_in) + +plt.tight_layout() +plt.show() +plt.savefig("trajectories.pdf", bbox_inches='tight') diff --git a/demos/re/hmc_wiener_filter.py 
b/demos/re/hmc_wiener_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..8664e9937d44de8aac90e672176703768cb6a02c
--- /dev/null
+++ b/demos/re/hmc_wiener_filter.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+#%%
+from jax import numpy as jnp
+from jax import lax, random
+import jax
+from jax.config import config
+import matplotlib
+import matplotlib.pyplot as plt
+
+import nifty8.re as jft
+
+config.update("jax_enable_x64", True)
+
+matplotlib.rcParams['figure.figsize'] = (10, 7)
+
+#%%
+dims = (512, )
+#datadims = (4,)
+loglogslope = 2.
+power_spectrum = lambda k: 1. / (k**loglogslope + 1.)
+modes = jnp.arange((dims[0] / 2) + 1., dtype=float)
+harmonic_power = power_spectrum(modes)
+harmonic_power = jnp.concatenate((harmonic_power, harmonic_power[-2:0:-1]))
+
+#%%
+correlated_field = lambda x: jft.correlated_field.hartley(
+    # x is a signal in Fourier space;
+    # each mode's amplitude gets multiplied by its harmonic power
+    # and the whole signal is transformed back
+    harmonic_power * x
+)
+
+# %%
+# signal_response = lambda x: jnp.exp(1. + correlated_field(x))
+signal_response = lambda x: correlated_field(x)
+# The signal response is $ \vec{d} = \begin{pmatrix} 1 \\ 1 \\ 1 \\ 1 \end{pmatrix} \cdot s + \vec{n} $ where $s \in \mathbb{R}$ and $\vec{n} \sim \mathcal{G}(0, N)$
+# signal_response = lambda x: jnp.ones(shape=datadims) * x
+noise_cov_inv_sqrt = lambda x: 1.**-1 * x
+
+#%%
+# create synthetic data
+seed = 43
+key = random.PRNGKey(seed)
+key, subkey = random.split(key)
+# normal random Fourier amplitudes
+pos_truth = random.normal(shape=dims, key=subkey)
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+# 1. / noise_cov_inv_sqrt(jnp.ones(dims)) is the standard deviation of the noise Gaussian
+noise_truth = 1. / noise_cov_inv_sqrt(jnp.ones(dims)
+                                     ) * random.normal(shape=dims, key=subkey)
+data = signal_response_truth + noise_truth
+
+#%%
+plt.plot(signal_response_truth, label='signal response')
+#plt.plot(noise_truth, label='noise', linewidth=0.5)
+plt.plot(data, 'k.', label='noisy data', markersize=4.)
+plt.xlabel('real space domain')
+plt.ylabel('field value')
+plt.legend()
+plt.title("signal and data")
+plt.show()
+
+
+#%%
+def Gaussian(data, noise_cov_inv_sqrt):
+    # Simple but not very generic Gaussian energy
+    def hamiltonian(primals):
+        p_res = primals - data
+        # TODO: is this the weighting with noise amplitude thing again?
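+        # (Answer: yes; noise_cov_inv_sqrt applies N^{-1/2} and thus whitens
+        #  the residual with the inverse noise standard deviation before it
+        #  is squared.)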
+ l_res = noise_cov_inv_sqrt(p_res) + return 0.5 * jnp.sum(l_res**2) + + return jft.Likelihood(hamiltonian, ) + + +# negative log likelihood +nll = Gaussian(data, noise_cov_inv_sqrt) @ signal_response + +#%% +ham = jft.StandardHamiltonian(likelihood=nll) +ham_gradient = jax.grad(ham) + + +# %% [markdown] +def plot_mean_and_stddev(ax, samples, mean_of_r=None, truth=False, **kwargs): + signal_response_of_samples = lax.map(signal_response, samples) + if mean_of_r == None: + mean_of_signal_response = jnp.mean(signal_response_of_samples, axis=0) + else: + mean_of_signal_response = mean_of_r + mean_label = kwargs.pop('mean_label', 'sample mean of signal response') + ax.plot(mean_of_signal_response, label=mean_label) + std_dev_of_signal_response = jnp.std(signal_response_of_samples, axis=0) + if truth: + ax.plot(signal_response_truth, label="truth") + ax.fill_between( + jnp.arange(len(mean_of_signal_response)), + y1=mean_of_signal_response - std_dev_of_signal_response, + y2=mean_of_signal_response + std_dev_of_signal_response, + color='grey', + alpha=0.5 + ) + title = kwargs.pop('title', 'position samples') + if title is not None: + ax.set_title(title) + xlabel = kwargs.pop('xlabel', 'position') + if xlabel is not None: + ax.set_xlabel(xlabel) + ylabel = kwargs.pop('ylabel', 'signal response') + if ylabel is not None: + ax.set_ylabel(ylabel) + ax.legend(loc='lower right', fontsize=8) + + +#%% +key, subkey = random.split(key) +initial_position = random.uniform(key=subkey, shape=pos_truth.shape) + +sampler = jft.HMCChain( + potential_energy=ham, + inverse_mass_matrix=1., + position_proto=initial_position, + step_size=0.05, + num_steps=128, +) + +chain, _ = sampler.generate_n_samples( + 42, initial_position, num_samples=30, save_intermediates=True +) +print(f"acceptance ratio: {chain.acceptance}") + +# %% +plot_mean_and_stddev(plt.gca(), chain.samples, truth=True) +plt.title("HMC position samples") +plt.show() + +# %% [markdown] +# # NUTS +jft.hmc._DEBUG_STORE = [] + +sampler = jft.NUTSChain( + position_proto=initial_position, + potential_energy=ham, + inverse_mass_matrix=1., + # 0.9193 # integrates to ~3-7, very smooth sample mean + # 0.8193 # integrates to depth ~22, very noisy sample mean + step_size=0.05, + max_tree_depth=17, +) + +chain, _ = sampler.generate_n_samples( + 42, initial_position, num_samples=30, save_intermediates=True +) +plt.hist(chain.depths, bins=jnp.arange(sampler.max_tree_depth + 2)) +plt.title('NUTS tree depth histogram') +plt.xlabel('tree depth') +plt.ylabel('count') +plt.show() + +# %% +plot_mean_and_stddev(plt.gca(), chain.samples, truth=True) +plt.title("NUTS position samples") +plt.show() + +# %% +if jft.hmc._DEBUG_FLAG: + debug_pos = jnp.array(jft.hmc._DEBUG_STORE)[:, 0, :] + + for idx, dbgp in enumerate(debug_pos): + plt.plot(signal_response(dbgp), label=f'{idx}', alpha=0.1) + #plt.legend() + + # %% + debug_pos_x = jnp.array(jft.hmc._DEBUG_STORE)[:, 0, 0] + debug_pos_y = jnp.array(jft.hmc._DEBUG_STORE)[:, 0, 1] + for idx, dbgp in enumerate(debug_pos): + plt.scatter(debug_pos_x, debug_pos_y, s=0.1, color='k') + #plt.legend() + plt.show() + +# %%[markdown] +# # 1D position and momentum time series +if chain.samples[0].shape == (1, ): + plt.plot(chain.samples, label='position') + #plt.plot(momentum_samples, label='momentum', linewidth=0.2) + #plt.plot(unintegrated_momenta, label='unintegrated momentum', linewidth=0.2) + plt.title('position and momentum time series') + plt.xlabel('time') + plt.ylabel('position, momentum') + plt.legend() + plt.show() + +# %% [markdown] +# 
+# %% [markdown]
+# # energy time series
+
+# %%
+potential_energies = lax.map(ham, chain.samples)
+kinetic_energies = jnp.sum(chain.trees.proposal_candidate.momentum**2, axis=1)
+#rejected_potential_energies = lax.map(ham, rejected_position_samples)
+#rejected_kinetic_energies = jnp.sum(rejected_momentum_samples**2, axis=1)
+plt.plot(potential_energies, label='pot')
+plt.plot(kinetic_energies, label='kin', linewidth=1)
+plt.plot(kinetic_energies + potential_energies, label='total', linewidth=1)
+#plt.plot(rejected_potential_energies, label='rejected_pot')
+#plt.plot(rejected_kinetic_energies, label='rejected_kin', linewidth=2)
+#plt.plot(rejected_kinetic_energies + rejected_potential_energies, label='rejected_total', linewidth=0.2)
+plt.title('NUTS energy time series')
+plt.xlabel('time')
+plt.ylabel('energy')
+plt.yscale('log')
+plt.legend()
+plt.show()
+
+# %% [markdown]
+# # Wiener Filter
+
+# Building blocks:
+# * `jax.linear_transpose` for R^\dagger
+# * squaring `noise_cov_inv_sqrt` for N^{-1}
+# * S is the identity thanks to the standardization in `StandardHamiltonian`
+# * `jax.scipy.sparse.linalg.cg` for applying D = (1 + R^\dagger N^{-1} R)^{-1}
+
+# %%
+# the primal `pos_truth` is only needed to infer the input shape; the
+# transposed function returns a (1,)-tuple, which we unpack right away
+_impl_signal_response_dagger = jax.linear_transpose(signal_response, pos_truth)
+signal_response_dagger = lambda d: _impl_signal_response_dagger(d)[0]
+# `noise_cov_inv_sqrt` is diagonal
+noise_cov_inv = lambda d: noise_cov_inv_sqrt(noise_cov_inv_sqrt(d))
+
+# the signal prior covariance S is the identity (see `StandardHamiltonian`)
+D_inv = lambda s: s + signal_response_dagger(noise_cov_inv(signal_response(s)))
+
+j = signal_response_dagger(noise_cov_inv(data))
+
+m, _ = jax.scipy.sparse.linalg.cg(D_inv, j)
+
+# %%
+plt.plot(signal_response(m), label='signal response of mean')
+plt.plot(signal_response_truth, label='true signal response')
+plt.legend()
+plt.title('Wiener Filter')
+plt.show()
+
+
+# %%
+def sample_from_d_inv(key):
+    s_inv_key, rnr_key = random.split(key)
+    s_inv_smpl = random.normal(s_inv_key, pos_truth.shape)
+    # draw a standard normal sample in data space and apply
+    # R^\dagger \sqrt{N^{-1}}
+    rnr_smpl = signal_response_dagger(
+        noise_cov_inv_sqrt(random.normal(rnr_key, signal_response_truth.shape))
+    )
+    return s_inv_smpl + rnr_smpl
+
+
+def sample_from_d(key):
+    d_inv_smpl = sample_from_d_inv(key)
+    # NOTE, the fixed iteration count truncates the CG solve; the resulting
+    # samples are hence only approximate
+    smpl, _ = jft.cg(D_inv, d_inv_smpl, maxiter=32)
+    return smpl
+
+
+wiener_samples = jnp.array(
+    list(map(lambda k: sample_from_d(k) + m, random.split(key, 30)))
+)
+
+# %%
+subplots = (3, 1)
+fig_height_pt = 541  # pt
+#fig_width_pt = 360  # pt
+inches_per_pt = 1 / 72.27
+fig_height_in = 1. * fig_height_pt * inches_per_pt
+fig_width_in = fig_height_in / 0.618 * (subplots[1] / subplots[0])
+fig_dims = (fig_width_in, fig_height_in)
+
+fig, (ax_raw, ax_nuts, ax_wiener) = plt.subplots(
+    subplots[0], subplots[1], sharex=True, sharey=False, figsize=fig_dims
+)
+
+ax_raw.plot(signal_response_truth, label='true signal response')
+ax_raw.plot(data, 'k.', label='noisy data', markersize=2.)
+#ax_raw.set_xlabel('position') +ax_raw.set_ylabel('signal response') +ax_raw.set_title("signal and data") +ax_raw.legend(fontsize=8) + +plot_mean_and_stddev( + ax_nuts, + chain.samples, + truth=True, + title="NUTS", + xlabel=None, + mean_label='sample mean' +) +plot_mean_and_stddev( + ax_wiener, + wiener_samples, + mean_of_r=signal_response(m), + truth=True, + title="Wiener Filter", + mean_label='exact posterior mean' +) + +fig.tight_layout() + +plt.savefig('wiener.pdf', bbox_inches='tight') +print("final plot saved as wiener.pdf") diff --git a/demos/re/lognorm_w_hmc.py b/demos/re/lognorm_w_hmc.py new file mode 100644 index 0000000000000000000000000000000000000000..d519b98de4f27e8bbfc8109d3c03cd9c6a62c328 --- /dev/null +++ b/demos/re/lognorm_w_hmc.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +# %% +from functools import partial +import sys + +from jax import numpy as jnp +from jax import lax, random +from jax import jit +from jax.config import config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +config.update("jax_enable_x64", True) + +seed = 42 +key = random.PRNGKey(seed) + + +# %% +def cartesian_product(arrays, out=None): + import numpy as np + + # Generalized N-dimensional products + arrays = [np.asarray(x) for x in arrays] + la = len(arrays) + dtype = np.find_common_type([a.dtype for a in arrays], []) + if out is None: + out = np.empty([len(a) for a in arrays] + [la], dtype=dtype) + for i, a in enumerate(np.ix_(*arrays)): + out[..., i] = a + return out.reshape(-1, la) + + +def helper_phi_b(b, x): + return b * x[0] * jnp.exp(b * x[1]) + + +# %% +b = 2. + +signal_response = partial(helper_phi_b, b) +nll = jft.Gaussian(0., lambda x: x / jnp.sqrt(1.)) @ signal_response + +ham = jft.StandardHamiltonian(nll).jit() +ham_vg = jit(jft.mean_value_and_grad(ham)) +ham_metric = jit(jft.mean_metric(ham.metric)) +MetricKL = jit( + partial(jft.MetricKL, ham), + static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") +) +GeoMetricKL = partial(jft.GeoMetricKL, ham) + +# %% +n_pix_sqrt = 1000 +x = jnp.linspace(-4, 4, n_pix_sqrt) +y = jnp.linspace(-4, 4, n_pix_sqrt) +xx = cartesian_product((x, y)) +ham_everywhere = jnp.vectorize(ham, signature="(2)->()")(xx).reshape( + n_pix_sqrt, n_pix_sqrt +) +plt.imshow( + jnp.exp(-ham_everywhere.T), + extent=(x.min(), x.max(), y.min(), y.max()), + origin="lower" +) +plt.colorbar() +plt.title("target distribution") +plt.show() + +# %% +n_mgvi_iterations = 30 +n_samples = [2] * (n_mgvi_iterations - 10) + [2] * 5 + [10, 10, 10, 10, 100] +n_newton_iterations = [7] * (n_mgvi_iterations - 10) + [10] * 6 + 4 * [25] +absdelta = 1e-13 + +initial_position = jnp.array([1., 1.]) +mkl_pos = 1e-2 * jft.Field(initial_position) + +mgvi_positions = [] + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + mg_samples = MetricKL( + mkl_pos, + n_samples[i], + key=subkey, + mirror_samples=True, + linear_sampling_kwargs={"absdelta": absdelta / 10.}, + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + x0=mkl_pos, + method="newton-cg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=mg_samples), + "hessp": partial(ham_metric, primals_samples=mg_samples), + "absdelta": absdelta, + "maxiter": n_newton_iterations[i], + "cg_kwargs": { + "name": None + }, + "name": "N" + } + ) + 
+    mkl_pos = opt_state.x
+    msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(mkl_pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+    mgvi_positions.append(mkl_pos)
+
+# %%
+n_geovi_iterations = 15
+n_samples = [1] * (n_geovi_iterations - 10) + [2] * 5 + [10, 10, 10, 10, 100]
+n_newton_iterations = [7] * (n_geovi_iterations - 10) + [10] * 6 + [25] * 4
+absdelta = 1e-10
+
+initial_position = jnp.array([1., 1.])
+gkl_pos = 1e-2 * jft.Field(initial_position)
+
+for i in range(n_geovi_iterations):
+    print(f"geoVI Iteration {i}", file=sys.stderr)
+    print("Sampling...", file=sys.stderr)
+    key, subkey = random.split(key, 2)
+    geo_samples = GeoMetricKL(
+        gkl_pos,
+        n_samples[i],
+        key=subkey,
+        mirror_samples=True,
+        linear_sampling_name=None,
+        linear_sampling_kwargs={"absdelta": absdelta / 10.},
+        non_linear_sampling_kwargs={
+            "cg_kwargs": {
+                "miniter": 0
+            },
+            "maxiter": 20
+        },
+    )
+
+    print("Minimizing...", file=sys.stderr)
+    opt_state = jft.minimize(
+        None,
+        x0=gkl_pos,
+        method="newton-cg",
+        options={
+            "fun_and_grad": partial(ham_vg, primals_samples=geo_samples),
+            "hessp": partial(ham_metric, primals_samples=geo_samples),
+            "absdelta": absdelta,
+            "maxiter": n_newton_iterations[i],
+            "cg_kwargs": {
+                "miniter": 0,
+                "name": None
+            },
+            "name": "N"
+        }
+    )
+    gkl_pos = opt_state.x
+    msg = f"Post geoVI Iteration {i}: Energy {geo_samples.at(gkl_pos).mean(ham):2.4e}"
+    print(msg, file=sys.stderr)
+
+# %%
+n_pix_sqrt = 200
+x = jnp.linspace(-4.0, 4.0, n_pix_sqrt, endpoint=True)
+y = jnp.linspace(-4.0, 4.0, n_pix_sqrt, endpoint=True)
+X, Y = jnp.meshgrid(x, y)
+XY = jnp.array([X, Y]).T
+xy = XY.reshape((XY.shape[0] * XY.shape[1], 2))
+es = jnp.exp(-lax.map(ham, xy)).reshape(XY.shape[:2]).T
+
+# %%
+mkl_b_space_smpls = jnp.array([s.val for s in mg_samples.at(mkl_pos)])
+
+fig, ax = plt.subplots()
+contour = ax.contour(X, Y, es)
+ax.clabel(contour, inline=True, fontsize=10)
+ax.scatter(*mkl_b_space_smpls.T)
+ax.plot(*mkl_pos, "rx")
+plt.title("MGVI")
+plt.show()
+
+# %%
+gkl_b_space_smpls = jnp.array([s.val for s in geo_samples.at(gkl_pos)])
+
+fig, ax = plt.subplots()
+contour = ax.contour(X, Y, es)
+ax.clabel(contour, inline=True, fontsize=10)
+ax.scatter(*gkl_b_space_smpls.T)
+ax.plot(*gkl_pos, "rx")
+plt.title("geoVI")
+plt.show()
+
+# %%
+initial_position = jnp.array([1., 1.])
+
+hmc_sampler = jft.HMCChain(
+    potential_energy=ham,
+    inverse_mass_matrix=1.,
+    position_proto=initial_position,
+    step_size=0.1,
+    num_steps=64,
+)
+
+chain, _ = hmc_sampler.generate_n_samples(
+    42, 1e-2 * initial_position, num_samples=100, save_intermediates=True
+)
+
+# %%
+b_space_smpls = chain.samples
+fig, ax = plt.subplots()
+ax.scatter(*b_space_smpls.T)
+plt.title("HMC (Metropolis-Hastings) samples")
+plt.show()
+
+# %%
+initial_position = jnp.array([1., 1.])
+
+nuts_sampler = jft.NUTSChain(
+    potential_energy=ham,
+    inverse_mass_matrix=0.5,
+    position_proto=initial_position,
+    step_size=0.4,
+    max_tree_depth=10,
+)
+
+nuts_n_samples = []
+ns_samples = [200, 1000, 1000000]
+for n in ns_samples:
+    chain, _ = nuts_sampler.generate_n_samples(
+        43 + n, 1e-2 * initial_position, num_samples=n, save_intermediates=True
+    )
+    nuts_n_samples.append(chain.samples)
+
+# %%
+b_space_smpls = chain.samples
+
+fig, ax = plt.subplots()
+contour = ax.contour(X, Y, es)
+ax.clabel(contour, inline=True, fontsize=10)
+ax.scatter(*b_space_smpls.T, s=2.)
+plt.show() + +# %% +plt.hist2d( + *b_space_smpls.T, + bins=[x, y], + range=[[x.min(), x.max()], [y.min(), y.max()]] +) +plt.colorbar() +plt.show() + +# %% +subplots = (3, 2) + +fig_width_pt = 426 # pt (a4paper, and such) +inches_per_pt = 1 / 72.27 +fig_width_in = fig_width_pt * inches_per_pt +fig_height_in = fig_width_in * 1. * (subplots[0] / subplots[1]) +fig_dims = (fig_width_in, fig_height_in) + +fig, ((ax1, ax4), (ax2, ax5), (ax3, ax6) + ) = plt.subplots(*subplots, figsize=fig_dims, sharex=True, sharey=True) + +ax1.set_title(r'$P(d=0|\xi_1, \xi_2) \cdot P(\xi_1, \xi_2)$') +xx = cartesian_product((x, y)) +ham_everywhere = jnp.vectorize(ham, signature="(2)->()")(xx).reshape( + n_pix_sqrt, n_pix_sqrt +) +ax1.imshow( + jnp.exp(-ham_everywhere.T), + extent=(x.min(), x.max(), y.min(), y.max()), + origin="lower" +) +#ax1.colorbar() + +ax1.set_ylim([-4., 4.]) +ax1.set_xlim([-4., 4.]) +#ax1.autoscale(enable=True, axis='y', tight=True) +asp = float( + jnp.diff(jnp.array(ax1.get_xlim()))[0] / + jnp.diff(jnp.array(ax1.get_ylim()))[0] +) + +smplmarkersize = .3 +smplmarkercolor = 'k' + +linewidths = 0.5 +fontsize = 5 +potlabels = False + +ax2.set_title('MGVI') +mkl_b_space_smpls = jnp.array([s.val for s in mg_samples.at(mkl_pos)]) +contour = ax2.contour(X, Y, es, linewidths=linewidths) +ax2.clabel(contour, inline=True, fontsize=fontsize) +ax2.scatter(*mkl_b_space_smpls.T, s=smplmarkersize, c=smplmarkercolor) +ax2.plot(*mkl_pos, "rx") +#ax2.set_aspect(asp) + +ax3.set_title('geoVI') +gkl_b_space_smpls = jnp.array([s.val for s in geo_samples.at(gkl_pos)]) +contour = ax3.contour(X, Y, es, linewidths=linewidths) +ax3.clabel(contour, inline=True, fontsize=fontsize) +ax3.scatter(*gkl_b_space_smpls.T, s=smplmarkersize, c=smplmarkercolor) +ax3.plot(*gkl_pos, "rx") +#ax3.set_aspect(asp) + +for i in range(3): + eval('ax' + str(i + 1)).set_ylabel(r'$\xi_2$') +ax3.set_xlabel(r'$\xi_1$') +ax6.set_xlabel(r'$\xi_1$') + +for n, samples, ax in zip(ns_samples[:2], nuts_n_samples[:2], [ax4, ax5]): + ax.set_title(f"NUTS N={n}") + contour = ax.contour(X, Y, es, linewidths=linewidths) + #ax.clabel(contour, inline=True, fontsize=fontsize) + ax.scatter(*samples.T, s=smplmarkersize, c=smplmarkercolor) + +h, _, _ = jnp.histogram2d( + *nuts_n_samples[-1].T, + bins=[x, y], + range=[[x.min(), x.max()], [y.min(), y.max()]] +) +ax6.imshow(h.T, extent=(x.min(), x.max(), y.min(), y.max()), origin="lower") +ax6.set_title(f'NUTS N={ns_samples[-1]:.0E}') + +fig.tight_layout() +fig.savefig("pinch.pdf", bbox_inches='tight') +print("final plot saved as pinch.pdf") diff --git a/demos/re/nifty_to_jifty.py b/demos/re/nifty_to_jifty.py new file mode 100644 index 0000000000000000000000000000000000000000..2b82053ebd9c1eaacc7aa2ba741fee4f9476e715 --- /dev/null +++ b/demos/re/nifty_to_jifty.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys + +from jax import numpy as jnp +from jax import random +from jax import jit +from jax.config import config +import matplotlib.pyplot as plt + +import nifty8.re as jft + +config.update("jax_enable_x64", True) + +# %% +# ## Likelihood +# +# ### What is a Likelihood in jifty? 
+#
+# * Very generally, the likelihood stores the cost term(s) for the final minimization
+# * P(d|\xi) is a likelihood just like P(\xi) is one (w/ d := data, \xi := parameters)
+# * Adding two likelihoods yields a likelihood again; thus P(d|\xi) + P(\xi) is just another likelihood
+# * Properties
+#   * Energy/Hamiltonian: negative log-probability
+#   * Left square root (L) of the metric (M; M = L L^\dagger): needed for sampling and minimization
+#   * Metric: needed for sampling and minimization; can be inferred from the left sqrt metric
+#
+# ### Differences to NIFTy's `EnergyOperator`?
+#
+# * There are no operators in jifty, thus there is no EnergyOperator!
+# * NIFTy features many different energy classes; in jifty there is just one
+# * jifty needs to track the domain of the data without re-introducing operators
+#
+# ### What gives?
+#
+# * No manual tracking of the Jacobian
+# * No linear operators; this also means we cannot take the adjoint of the Jacobian :(
+# * Trivial to define new likelihoods (see the Poisson sketch after the MAP example below)
+
+
+def Gaussian(data, noise_cov_inv_sqrt):
+    # Simple but not very generic Gaussian energy
+    def hamiltonian(primals):
+        p_res = primals - data
+        l_res = noise_cov_inv_sqrt(p_res)
+        return 0.5 * jnp.sum(l_res**2)
+
+    def left_sqrt_metric(primals, tangents):
+        return noise_cov_inv_sqrt(tangents)
+
+    lsm_tangents_shape = jnp.shape(data)
+    # Better: `tree_map(ShapeWithDtype.from_leave, data)`
+
+    return jft.Likelihood(
+        hamiltonian,
+        left_sqrt_metric=left_sqrt_metric,
+        lsm_tangents_shape=lsm_tangents_shape
+    )
+
+
+seed = 42
+key = random.PRNGKey(seed)
+
+dims = (1024, )
+
+loglogslope = 2.
+power_spectrum = lambda k: 1. / (k**loglogslope + 1.)
+modes = jnp.arange((dims[0] / 2) + 1., dtype=float)
+harmonic_power = power_spectrum(modes)
+harmonic_power = jnp.concatenate((harmonic_power, harmonic_power[-2:0:-1]))
+
+# Specify the model
+correlated_field = lambda x: jft.correlated_field.hartley(
+    harmonic_power * x.val
+)
+signal_response = lambda x: jnp.exp(1. + correlated_field(x))
+noise_cov_inv_sqrt = lambda x: 0.1**-1 * x
+
+# Create synthetic data
+key, subkey = random.split(key)
+pos_truth = jft.Field(random.normal(shape=dims, key=subkey))
+signal_response_truth = signal_response(pos_truth)
+key, subkey = random.split(key)
+noise_truth = 1. / noise_cov_inv_sqrt(jnp.ones(dims)
+                                      ) * random.normal(shape=dims, key=subkey)
+data = signal_response_truth + noise_truth
+
+nll = Gaussian(data, noise_cov_inv_sqrt) @ signal_response
+ham = jft.StandardHamiltonian(likelihood=nll).jit()
+ham_vg = jit(jft.mean_value_and_grad(ham))
+ham_metric = jit(jft.mean_metric(ham.metric))
+MetricKL = jit(
+    partial(jft.MetricKL, ham),
+    static_argnames=("n_samples", "mirror_samples", "linear_sampling_name")
+)
+
+key, subkey = random.split(key)
+pos_init = jft.Field(random.normal(shape=dims, key=subkey))
+pos = jft.Field(pos_init.val)
+
+n_newton_iterations = 10
+# Maximize the posterior using natural gradient scaling
+pos = jft.newton_cg(
+    fun_and_grad=ham_vg, x0=pos, hessp=ham_metric, maxiter=n_newton_iterations
+)
+
+fig, ax = plt.subplots()
+ax.plot(signal_response_truth, alpha=0.7, label="Signal")
+ax.plot(noise_truth, alpha=0.7, label="Noise")
+ax.plot(data, alpha=0.7, label="Data")
+ax.plot(signal_response(pos), alpha=0.7, label="Reconstruction")
+ax.legend()
+fig.tight_layout()
+fig.savefig("n2f_known_spectrum_MAP.png", dpi=400)
+plt.close()
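+# %%
+# The "trivial to define new likelihoods" bullet from above in action: a
+# sketch of a Poisson energy in the same style as `Gaussian`. It is an
+# illustration only, not jifty's built-in Poisson likelihood, and assumes the
+# forward model yields strictly positive rates.
+def Poissonian(data):
+    def hamiltonian(primals):
+        # negative log-likelihood of Poisson counts `data` given the rates
+        # `primals`, up to a constant independent of the rates
+        return jnp.sum(primals) - jnp.sum(data * jnp.log(primals))
+
+    def left_sqrt_metric(primals, tangents):
+        # square root of the Fisher metric 1 / rate
+        return tangents / jnp.sqrt(primals)
+
+    return jft.Likelihood(
+        hamiltonian,
+        left_sqrt_metric=left_sqrt_metric,
+        lsm_tangents_shape=jnp.shape(data)
+    )
+
+
+# ## Sampling
+#
+# ### How does sampling work in jifty?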
+# +# To sample from a likelihood, we need to be able to draw samples which have +# the metric as covariance structure and we need to be able to apply the +# inverse metric. The first part is trivial since we can use the left square +# root of the metric associated with every likelihood: +# +# \tilde{d} \leftarrow \mathcal{G}(0,\mathbb{1}) +# t = L \tilde{d} +# +# with $t$ now having a covariance structure of +# +# <t t^\dagger> = L <\tilde{d} \tilde{d}^\dagger> L^\dagger = M. +# +# We now need to apply the inverse metric in order to transform the sample to +# an inverse sample. We can do so using the conjugate gradient algorithm which +# yields the solution to $M s = t$, i.e. applies the inverse of $M$ to $t$: +# +# M s = t +# s = M^{-1} t = cg(M, t) . +# +# ### Differences to NIFTy? +# +# * More generic implementation since the left square root of the metric can +# be applied independently from drawing samples +# * By virtue of storing the left square root metric, no dedicated sampling +# method needs to be extended ever again +# +# ### What gives? +# +# The clearer separation of sampling and inverting the metric allows for a +# better interplay of our methods with existing tools like JAX's cg +# implementation. + +n_mgvi_iterations = 3 +n_samples = 4 +n_newton_iterations = 5 + +key, subkey = random.split(key) +pos_init = jft.Field(random.normal(shape=dims, key=subkey)) +pos = jft.Field(pos_init.val) + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + mg_samples = MetricKL( + pos, + n_samples=n_samples, + key=subkey, + mirror_samples=True, + ) + + print("Minimizing...", file=sys.stderr) + pos = jft.newton_cg( + fun_and_grad=partial(ham_vg, primals_samples=mg_samples), + x0=pos, + hessp=partial(ham_metric, primals_samples=mg_samples), + maxiter=n_newton_iterations + ) + msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(pos).mean(ham):2.4e}" + print(msg, file=sys.stderr) + +post_sr_mean = jft.mean(tuple(signal_response(s) for s in mg_samples.at(pos))) +fig, ax = plt.subplots() +ax.plot(signal_response_truth, alpha=0.7, label="Signal") +ax.plot(noise_truth, alpha=0.7, label="Noise") +ax.plot(data, alpha=0.7, label="Data") +ax.plot(post_sr_mean, alpha=0.7, label="Reconstruction") +label = "Reconstructed samples" +for s in mg_samples: + ax.plot(signal_response(s), color="gray", alpha=0.5, label=label) + label = None +ax.legend() +fig.tight_layout() +fig.savefig("n2f_known_spectrum_MGVI.png", dpi=400) +plt.close() + +# ## Correlated field +# +# ### Correlated fields in jifty +# +# * `CorrelatedFieldMaker` to track amplitudes along different axes +# * `add_fluctuations` method to amend new amplitudes +# * Zero-mode is tracked separately to the amplitudes +# * `finalize` normalizes the amplitudes and takes their outer product +# * Amplitudes are independent of the stack of amplitudes tracked in the correlated field, i.e. no normalization happens within the amplitude +# +# ### Differences to NIFTy +# +# A correlated field with a single axis but arbitrary dimensionality in NIFTy +# is mostly equivalent to one in jifty. Though since jifty does not track +# domains, everything related to harmonic modes and distributing power is +# contained within the correlated field model. +# +# The normalization and factorization of amplitudes is done only once in +# `finalize`. This conceptually simplifies the amplitude model by a lot. +# +# ### What gives? 
+# +# * Conceptually simpler amplitude model +# * No domains --> no domain mismatches --> broadcasting \o/ +# * No domains --> no domain mismatches --> more errors :( + +dims_ax1 = (64, ) +dims_ax2 = (128, ) +cf_zm = {"offset_mean": 0., "offset_std": (1e-3, 1e-4)} +cf_fl = { + "fluctuations": (1e-1, 5e-3), + "loglogavgslope": (-1., 1e-2), + "flexibility": (1e+0, 5e-1), + "asperity": (5e-1, 1e-1), + "harmonic_domain_type": "Fourier" +} +cfm = jft.CorrelatedFieldMaker("cf") +cfm.set_amplitude_total_offset(**cf_zm) +d = 1. / dims_ax1[0] +cfm.add_fluctuations(dims_ax1, distances=d, **cf_fl, prefix="ax1") +d = 1. / dims_ax2[0] +cfm.add_fluctuations(dims_ax2, distances=d, **cf_fl, prefix="ax2") +correlated_field, ptree = cfm.finalize() + +signal_response = lambda x: correlated_field(x) +noise_cov = lambda x: 5**2 * x +noise_cov_inv = lambda x: 5**-2 * x + +# Create synthetic data +key, subkey = random.split(key) +pos_truth = jft.random_like(subkey, ptree) +signal_response_truth = signal_response(pos_truth) +key, subkey = random.split(key) +noise_truth = jnp.sqrt( + noise_cov(jnp.ones(signal_response_truth.shape)) +) * random.normal(shape=signal_response_truth.shape, key=key) +data = signal_response_truth + noise_truth + +nll = jft.Gaussian(data, noise_cov_inv) @ signal_response +ham = jft.StandardHamiltonian(likelihood=nll).jit() +ham_vg = jit(jft.mean_value_and_grad(ham)) +ham_metric = jit(jft.mean_metric(ham.metric)) +MetricKL = jit( + partial(jft.MetricKL, ham), + static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") +) + +key, subkey = random.split(key) +pos_init = jft.Field(jft.random_like(subkey, ptree)) +pos = jft.Field(pos_init.val) + +n_mgvi_iterations = 3 +n_samples = 4 +n_newton_iterations = 10 + +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + mg_samples = MetricKL( + pos, + n_samples=n_samples, + key=subkey, + mirror_samples=True, + ) + + print("Minimizing...", file=sys.stderr) + pos = jft.newton_cg( + fun_and_grad=partial(ham_vg, primals_samples=mg_samples), + x0=pos, + hessp=partial(ham_metric, primals_samples=mg_samples), + maxiter=n_newton_iterations + ) + msg = f"Post MGVI Iteration {i}: Energy {mg_samples.at(pos).mean(ham):2.4e}" + print(msg, file=sys.stderr) + +namps = cfm.get_normalized_amplitudes() +post_sr_mean = jft.mean(tuple(signal_response(s) for s in mg_samples.at(pos))) +post_namps1_mean = jft.mean(tuple(namps[0](s)[1:] for s in mg_samples.at(pos))) +post_namps2_mean = jft.mean(tuple(namps[1](s)[1:] for s in mg_samples.at(pos))) +to_plot = [ + ("Signal", signal_response_truth, "im"), + ("Noise", noise_truth, "im"), + ("Data", data, "im"), + ("Reconstruction", post_sr_mean, "im"), + ("Ax1", (namps[0](pos_truth)[1:], post_namps1_mean), "loglog"), + ("Ax2", (namps[1](pos_truth)[1:], post_namps2_mean), "loglog"), +] +fig, axs = plt.subplots(2, 3, figsize=(16, 9)) +for ax, (title, field, tp) in zip(axs.flat, to_plot): + ax.set_title(title) + if tp == "im": + im = ax.imshow(field, cmap="inferno") + plt.colorbar(im, ax=ax, orientation="horizontal") + else: + ax_plot = ax.loglog if tp == "loglog" else ax.plot + field = field if isinstance(field, (tuple, list)) else (field, ) + for f in field: + ax_plot(f, alpha=0.7) +fig.tight_layout() +fig.savefig("n2f_unknown_factorizing_spectra.png", dpi=400) +plt.close() diff --git a/demos/re/refine.py b/demos/re/refine.py new file mode 100644 index 
0000000000000000000000000000000000000000..872662fe6356c5fe2fdf6a8312bd5772870025b6 --- /dev/null +++ b/demos/re/refine.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 + +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from collections import namedtuple +from functools import partial +import sys + +import jax +from jax import numpy as jnp +from jax import random +from jax.scipy.interpolate import RegularGridInterpolator +import matplotlib.pyplot as plt +import numpy as np +from scipy.special import kv as mod_bessel2 + +import nifty8.re as jft + +jax.config.update("jax_enable_x64", True) +# jax.config.update("jax_debug_nans", True) + +Timed = namedtuple("Timed", ("time", "number"), rename=True) + + +def timeit(stmt, setup=lambda: None, number=None): + import timeit + + if number is None: + number, _ = timeit.Timer(stmt).autorange() + + setup() + t = timeit.timeit(stmt, number=number) / number + return Timed(time=t, number=number) + + +def _matern_kernel(distance, scale, cutoff, dof): + from jax.scipy.special import gammaln + + reg_dist = jnp.sqrt(2 * dof) * distance / cutoff + return scale**2 * 2**(1 - dof) / jnp.exp( + gammaln(dof) + ) * (reg_dist)**dof * mod_bessel2(dof, reg_dist) + + +n_dof = 100 +n_dist = 1000 +min_reg_dist = 1e-6 # approx. lowest resolution of `_matern_kernel` at float64 +max_reg_dist = 8e+2 # approx. highest resolution of `_matern_kernel` at float64 +eps = 8. * jnp.finfo(jnp.array(min_reg_dist).dtype.type).eps +dof_grid = np.linspace(0., 15., n_dof) +reg_dist_grid = np.logspace( + np.log(min_reg_dist * (1. - eps)), + np.log(max_reg_dist * (1. + eps)), + base=np.e, + num=n_dist +) +grid = np.meshgrid(dof_grid, reg_dist_grid, indexing="ij") +_unsafe_ln_mod_bessel2 = RegularGridInterpolator( + (dof_grid, reg_dist_grid), jnp.log(mod_bessel2(*grid)), fill_value=-np.inf +) + + +def matern_kernel(distance, scale, cutoff, dof): + from jax.scipy.special import gammaln + + reg_dist = jnp.sqrt(2 * dof) * distance / cutoff + dof, reg_dist = jnp.broadcast_arrays(dof, reg_dist) + + # Never produce NaNs (https://github.com/google/jax/issues/1052) + reg_dist = reg_dist.clip(min_reg_dist, max_reg_dist) + + ln_kv = jnp.squeeze( + _unsafe_ln_mod_bessel2(jnp.stack((dof, reg_dist), axis=-1)) + ) + corr = 2**(1 - dof) * jnp.exp(ln_kv - gammaln(dof)) * (reg_dist)**dof + return scale**2 * corr + + +scale, cutoff, dof = 1., 80., 3 / 2 + +# %% +x = np.logspace(-6, 11, base=jnp.e, num=int(1e+5)) +y = _matern_kernel(x, scale, cutoff, dof) +y = jnp.nan_to_num(y, nan=0.) +kernel = partial(jnp.interp, xp=x, fp=y) +kernel_j = partial(matern_kernel, scale=scale, cutoff=cutoff, dof=dof) + +fig, ax = plt.subplots() +x_s = x[x < 10 * cutoff] +ax.plot(x_s, kernel(x_s)) +ax.plot(x_s, kernel_j(x_s)) +ax.plot(x_s, jnp.exp(-(x_s / (2. 
* cutoff))**2)) +ax.set_yscale("log") +fig.savefig("re_refine_kernel.png", transparent=True) +plt.close() + +# %% +# Quick demo of the correlated field scheme that is to be used in the following +cf_kwargs = {"shape0": (12, ), "distances0": (50., ), "kernel": kernel} + +cf = jft.RefinementField(**cf_kwargs, depth=5) +xi = jft.random_like(random.PRNGKey(42), cf.shapewithdtype) + +fig, ax = plt.subplots(figsize=(8, 4)) +for i in range(cf.chart.depth): + cf_lvl = jft.RefinementField(**cf_kwargs, depth=i) + x = jnp.mgrid[tuple(slice(sz) for sz in cf_lvl.chart.shape)] + x = cf.chart.ind2rg(x, i)[0] + f_lvl = cf_lvl(xi[:i + 1]) + ax.step(x, f_lvl, alpha=0.7, where="mid", label=f"level {i}") +# ax.set_frame_on(False) +# ax.set_xticks([], []) +# ax.set_yticks([], []) +ax.legend() +fig.tight_layout() +fig.savefig("re_refine_field_layers.png", transparent=True) +plt.close() + + +# %% +def parametrized_kernel(xi, verbose=False): + scale = jnp.exp(-0.5 + 0.2 * xi["lat_scale"]) + cutoff = jnp.exp(4. + 1e-2 * xi["lat_cutoff"]) + # dof = jnp.exp(0.5 + 0.1 * xi["lat_dof"]) + # kernel = lambda r: xi["scale"] * jnp.exp(-(r / xi["cutoff"])**2) + if verbose: + print(f"{scale=}, {cutoff=}, {dof=}") + + return partial(matern_kernel, scale=scale, cutoff=cutoff, dof=dof) + + +def signal_response(xi): + return cf(xi["excitations"], parametrized_kernel(xi)) + + +n_std = 0.5 + +key = random.PRNGKey(45) +key, *key_splits = random.split(key, 4) + +xi_truth = jft.random_like(key_splits.pop(), cf.shapewithdtype) +d = cf(xi_truth, kernel) +d += n_std * random.normal(key_splits.pop(), shape=d.shape) + +xi_swd = { + "excitations": cf.shapewithdtype, + "lat_scale": jft.ShapeWithDtype(()), + "lat_cutoff": jft.ShapeWithDtype(()), +} +pos = 1e-4 * jft.Field(jft.random_like(key_splits.pop(), xi_swd)) + +n_mgvi_iterations = 15 +n_newton_iterations = 15 +n_samples = 2 +absdelta = 1e-5 + +nll = jft.Gaussian(d, noise_std_inv=lambda x: x / n_std) @ signal_response +ham = jft.StandardHamiltonian(nll) # + 0.5 * jft.norm(x, ord=2, ravel=True) +ham_vg = jax.jit(jft.mean_value_and_grad(ham)) +ham_metric = jax.jit(jft.mean_metric(ham.metric)) +MetricKL = jax.jit( + partial(jft.MetricKL, ham), + static_argnames=("n_samples", "mirror_samples", "linear_sampling_name") +) + +# %% +# Minimize the potential +for i in range(n_mgvi_iterations): + print(f"MGVI Iteration {i}", file=sys.stderr) + print("Sampling...", file=sys.stderr) + key, subkey = random.split(key, 2) + samples = MetricKL( + pos, + n_samples=n_samples, + key=subkey, + mirror_samples=True, + linear_sampling_kwargs={"absdelta": absdelta / 10.} + ) + + print("Minimizing...", file=sys.stderr) + opt_state = jft.minimize( + None, + x0=pos, + method="newton-cg", + options={ + "fun_and_grad": partial(ham_vg, primals_samples=samples), + "hessp": partial(ham_metric, primals_samples=samples), + "absdelta": absdelta, + "maxiter": n_newton_iterations + } + ) + pos = opt_state.x + msg = f"Post MGVI Iteration {i}: Energy {samples.at(pos).mean(ham):2.4e}" + print(msg, file=sys.stderr) + +# %% +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot(d, label="data") +ax.plot(cf(xi_truth, kernel), label="truth") +ax.plot(samples.at(pos).mean(signal_response), label="reconstruction") +ax.legend() +fig.tight_layout() +fig.savefig("re_refine_reconstruction.png", transparent=True) +plt.close() + +# %% +cf_bench = jft.RefinementField(shape0=(12, ), kernel=kernel, depth=15) +xi_wo = jft.random_like(random.PRNGKey(42), jft.Field(cf_bench.shapewithdtype)) +xi_w = jft.random_like( + random.PRNGKey(42), + jft.Field( 
+ { + "excitations": cf_bench.shapewithdtype, + "lat_scale": jft.ShapeWithDtype(()), + "lat_cutoff": jft.ShapeWithDtype(()), + } + ) +) + + +def signal_response_bench(xi): + return cf_bench(xi["excitations"], parametrized_kernel(xi)) + + +d = signal_response_bench(0.5 * xi_w) +nll_wo_fwd = jft.Gaussian(d, noise_std_inv=lambda x: x / n_std) +ham_w = jft.StandardHamiltonian(nll_wo_fwd @ signal_response_bench) +ham_wo = jft.StandardHamiltonian(nll_wo_fwd @ cf_bench) + +# %% +all_backends = {"cpu"} +all_backends |= {jax.default_backend()} +for backend in all_backends: + device_kw = {"device": jax.devices(backend=backend)[0]} + device_put = partial(jax.device_put, **device_kw) + + cf_vag_bench = jax.jit(jax.value_and_grad(ham_w), **device_kw) + x = device_put(xi_w) + _ = jax.block_until_ready(cf_vag_bench(x)) + t = timeit(lambda: jax.block_until_ready(cf_vag_bench(x))) + ti, num = t.time, t.number + + msg = f"{backend.upper()} :: Shape {str(cf_bench.chart.shape):>16s} ({num:6d} loops) :: JAX w/ learnable {ti:4.2e}" + print(msg, file=sys.stderr) + + cf_vag_bench = jax.jit(jax.value_and_grad(ham_wo), **device_kw) + x = device_put(xi_wo) + _ = jax.block_until_ready(cf_vag_bench(x)) + t = timeit(lambda: jax.block_until_ready(cf_vag_bench(x))) + ti, num = t.time, t.number + + msg = f"{backend.upper()} :: Shape {str(cf_bench.chart.shape):>16s} ({num:6d} loops) :: JAX w/o learnable {ti:4.2e}" + print(msg, file=sys.stderr) diff --git a/setup.py b/setup.py index 015e37aeb3aea54778f2a0441e0bc13dc7e0aae4..2de928b7fb12c546fa020b72f4b2f830787c9e4f 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,8 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. +from functools import reduce +import operator import os import site import sys @@ -30,6 +32,14 @@ with open("README.md") as f: long_description = f.read() description = "Library for signal inference algorithms that operate regardless of the underlying grids and their resolutions." +extras_require = { + "re": ("jax", ), + "native": ("ducc0", "finufft"), + "doc": ("sphinx", "pydata-sphinx-theme", "jupyter", "jupytext"), + "util": ("astropy", ), +} +extras_require["full"] = reduce(operator.add, extras_require.values()) + setup(name="nifty8", version=__version__, author="Martin Reinecke", @@ -48,6 +58,7 @@ setup(name="nifty8", license="GPLv3", setup_requires=['scipy>=1.4.1', 'numpy>=1.17'], install_requires=['scipy>=1.4.1', 'numpy>=1.17'], + extras_require=extras_require, python_requires='>=3.7', classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/src/__init__.py b/src/__init__.py index 8d0711b93e754980a15aa5138f3f18a376069154..9ac933d086e45c38ffaa16d8963bbfe76f0c557e 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -108,5 +108,11 @@ from .operator_tree_optimiser import optimise_operator from .ducc_dispatch import set_nthreads, nthreads +try: + from . import re + from . import nifty2jax +except ImportError: + pass + # We deliberately don't set __all__ here, because we don't want people to do a # "from nifty8 import *"; that would swamp the global namespace. 
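The two hunks above make JAX a strictly optional dependency: `pip install
"nifty8[re]"` pulls it in via the new `extras_require`, and `nifty8.re` /
`nifty8.nifty2jax` simply stay unavailable otherwise. A minimal sketch of how
downstream code can mirror this guard; the helper names (`HAVE_JAX`,
`make_hamiltonian`) are illustrative, not part of NIFTy's API:

    import nifty8 as ift

    try:
        import nifty8.re as jft
        HAVE_JAX = True
    except ImportError:
        jft, HAVE_JAX = None, False

    def make_hamiltonian(likelihood):
        # prefer the JAX-backed implementation when available and fall back
        # to the NumPy-based operators otherwise
        if HAVE_JAX:
            return jft.StandardHamiltonian(likelihood=likelihood)
        return ift.StandardHamiltonian(likelihood)
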
diff --git a/src/domains/structured_domain.py b/src/domains/structured_domain.py index 3ab7d9d4bdce268311a4dbb115fc6a9b17ed1110..0c3ad3ac8461531320094afdd81895125fc1abfa 100644 --- a/src/domains/structured_domain.py +++ b/src/domains/structured_domain.py @@ -70,7 +70,7 @@ class StructuredDomain(Domain): Returns ------- - Field + :class:`nifty8.field.Field` An array containing the k vector lengths """ raise NotImplementedError diff --git a/src/extra.py b/src/extra.py index c609295d9ec7aaefadd29f76d902400f217d3099..1cffcd7c2718cb289a9d28e3dc782d8790f177d6 100644 --- a/src/extra.py +++ b/src/extra.py @@ -107,7 +107,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True, ---------- op : Operator Operator which shall be checked. - loc : Field or MultiField + loc : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` An Field or MultiField instance which has the same domain as op. The location at which the gradient is checked tol : float diff --git a/src/field.py b/src/field.py index dcf3c1230efd74a10d3e14ecf9d1997141cf8f97..144ce2320634767150f9fb22ed12589a969929be 100644 --- a/src/field.py +++ b/src/field.py @@ -53,6 +53,10 @@ class Field(Operator): if not isinstance(val, np.ndarray): if np.isscalar(val): val = np.broadcast_to(val, domain.shape) + elif np.shape(val) == domain.shape: + # If NumPy thinks the shapes are equal, attempt to convert to + # NumPy. This is especially helpful for JAX DeviceArrays. + val = np.asarray(val) else: raise TypeError("val must be of type numpy.ndarray") if domain.shape != val.shape: @@ -276,7 +280,7 @@ class Field(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` Returns ------- @@ -294,7 +298,7 @@ class Field(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` x must be defined on the same domain as `self`. spaces : None, int or tuple of int @@ -326,7 +330,7 @@ class Field(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` x must be defined on the same domain as `self`. Returns diff --git a/src/library/adjust_variances.py b/src/library/adjust_variances.py index 883d6cf5da87e574c346f4cca6719aab7339a6a5..5103261ea948d1c1f022033b87b5aa3115a84d6d 100644 --- a/src/library/adjust_variances.py +++ b/src/library/adjust_variances.py @@ -44,9 +44,9 @@ def make_adjust_variances_hamiltonian(a, xi : Operator Field Adapter selecting a part of position. xi is desired to be a Gaussian white Field. - position : Field, MultiField + position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` Contains the initial values for the operators a and xi, to be adjusted - samples : Field, MultiField + samples : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` Residual samples of position. scaling : Float Optional rescaling of the Likelihood. @@ -55,7 +55,7 @@ def make_adjust_variances_hamiltonian(a, Returns ------- - StandardHamiltonian + :class:`nifty8.operators.energy_operators.StandardHamiltonian` A Hamiltonian that can be used for further minimization. """ @@ -91,7 +91,7 @@ def do_adjust_variances(position, A, minimizer, xi_key='xi', samples=[]): Parameters ---------- - position : Field, MultiField + position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` Contains the initial values for amplitude_operator and the key xi_key, to be adjusted. 
A : Operator @@ -101,7 +101,7 @@ def do_adjust_variances(position, A, minimizer, xi_key='xi', samples=[]): xi_key : String Key of the Field containing undesired variations. This Field is contained in position. - samples : Field, MultiField, optional + samples : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField`, optional Residual samples of position. If samples are supplied then phi remains only approximately constant. Default: []. diff --git a/src/library/correlated_fields.py b/src/library/correlated_fields.py index 6b17481788288b247e5de4f356e4ed0d4eb500a2..9b962c8df7449b9e41d3a644afd9df504653eaf4 100644 --- a/src/library/correlated_fields.py +++ b/src/library/correlated_fields.py @@ -27,6 +27,7 @@ import numpy as np from .. import utilities from ..domain_tuple import DomainTuple from ..domains.power_space import PowerSpace +from ..domains.rg_space import RGSpace from ..domains.unstructured_domain import UnstructuredDomain from ..field import Field from ..logger import logger @@ -38,7 +39,6 @@ from ..operators.distributors import PowerDistributor from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.linear_operator import LinearOperator -from ..operators.mask_operator import MaskOperator from ..operators.normal_operators import LognormalTransform, NormalTransform from ..operators.operator import Operator from ..operators.simple_linear_operators import VdotOperator, ducktape @@ -453,6 +453,16 @@ class CorrelatedFieldMaker: self._prefix = prefix self._total_N = total_N + try: + from .. import re as jft + + if total_N != 0: + warn(f"unable to add JAX operator for total_N={total_N}") + raise ImportError("short-circuit JAX init") + self._jax_cfm = jft.CorrelatedFieldMaker(prefix=prefix) + except ImportError: + self._jax_cfm = None + def add_fluctuations(self, target_subdomain, fluctuations, @@ -566,6 +576,25 @@ class CorrelatedFieldMaker: target_subdomain[-1].total_volume, pre + 'spectrum', dofdex) + is_rg = all(isinstance(dom, RGSpace) for dom in target_subdomain) + if self._jax_cfm is not None and (len(dofdex) > 0 or index or not is_rg): + warn(f"unable to add JAX operator for {target_subdomain}") + self._jax_cfm = None + if self._jax_cfm is not None: + dists = tuple(e for di in target_subdomain for e in di.distances) + self._jax_cfm.add_fluctuations( + shape=target_subdomain.shape, + distances=dists, + fluctuations=fluctuations, + loglogavgslope=loglogavgslope, + flexibility=flexibility, + asperity=asperity, + prefix=str(prefix), + harmonic_domain_type="fourier", + non_parametric_kind="power", + ) + amp._jax_expr = self._jax_cfm.fluctuations[-1] + if index is not None: self._a.insert(index, amp) self._target_subdomains.insert(index, target_subdomain) @@ -652,6 +681,10 @@ class CorrelatedFieldMaker: amp = _AmplitudeMatern(pow_spc, scale, cutoff, loglogslope, totvol) + if self._jax_cfm is not None: + warn(f"unable to add JAX operator for Matern fluctuations") + self._jax_cfm = None + self._a.append(amp) self._target_subdomains.append(target_subdomain) @@ -684,12 +717,15 @@ class CorrelatedFieldMaker: logger.warning("Overwriting the previous mean offset and zero-mode") self._offset_mean = offset_mean + jax_offset_std = offset_std if offset_std is None: self._azm = 0. elif np.isscalar(offset_std) and offset_std == 1.: self._azm = 1. + jax_offset_std = lambda _: 1. 
elif isinstance(offset_std, Operator): self._azm = offset_std + jax_offset_std = offset_std.jax_expr else: if dofdex is None: dofdex = np.full(self._total_N, 0) @@ -710,6 +746,21 @@ class CorrelatedFieldMaker: zm = _Distributor(dofdex, zm.target, UnstructuredDomain(self._total_N)) @ zm self._azm = zm + if self._jax_cfm is not None and dofdex is not None and len(dofdex) > 0: + warn(f"unable to add JAX operator for dofdex={dofdex}") + self._jax_cfm = None + if self._jax_cfm is not None: + try: + self._jax_cfm.set_amplitude_total_offset( + offset_mean=offset_mean, offset_std=jax_offset_std + ) + if not isinstance(self._azm, float): + self._azm._jax_expr = self._jax_cfm.azm + except TypeError as e: + self._jax_cfm = None + if isinstance(e, TypeError): + warn(f"no JAX operator for this configuration;\n{e}") + def finalize(self, prior_info=100): """Finishes model construction process and returns the constructed operator. @@ -766,6 +817,10 @@ class CorrelatedFieldMaker: offset = float(offset) op = Adder(full(op.target, offset)) @ op self.statistics_summary(prior_info) + + if self._jax_cfm is not None: + cf, _ = self._jax_cfm.finalize() + op._jax_expr = cf return op def statistics_summary(self, prior_info): @@ -829,8 +884,13 @@ class CorrelatedFieldMaker: elif self.azm == 1: return self.fluctuations + if self._jax_cfm: + normed_amps_jax = self._jax_cfm.get_normalized_amplitudes() + else: + normed_amps_jax = (None, ) * len(self._a) + normal_amp = [] - for amp in self._a: + for amp, na_jax in zip(self._a, normed_amps_jax): a_target = amp.target a_space = 0 if not hasattr(amp, "_space") else amp._space a_pp = amp.target[a_space] @@ -852,7 +912,9 @@ class CorrelatedFieldMaker: zm_normalization = zm_unmask @ ( zm_mask @ azm_expander(self.azm.ptw("reciprocal")) ) - normal_amp.append(zm_normalization * amp) + na = zm_normalization * amp + na._jax_expr = na_jax + normal_amp.append(na) return tuple(normal_amp) @property @@ -865,12 +927,16 @@ class CorrelatedFieldMaker: normal_amp = self.get_normalized_amplitudes()[0] if np.isscalar(self.azm): - return normal_amp + na = normal_amp else: expand = ContractionOperator( normal_amp.target, len(normal_amp.target) - 1 ).adjoint - return normal_amp * (expand @ self.azm) + na = normal_amp * (expand @ self.azm) + + if self._jax_cfm: + na._jax_expr = self._jax_cfm.amplitude + return na @property def power_spectrum(self): diff --git a/src/library/correlated_fields_simple.py b/src/library/correlated_fields_simple.py index ca754dd243d231456df5f7ed5aa2152355d9961c..57d53e39f959fed5dc106cb50dbfc6c9225aac84 100644 --- a/src/library/correlated_fields_simple.py +++ b/src/library/correlated_fields_simple.py @@ -17,6 +17,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np +from warnings import warn from ..domain_tuple import DomainTuple from ..domains.power_space import PowerSpace @@ -130,4 +131,38 @@ def SimpleCorrelatedField( op = Adder(full(op.target, float(offset_mean))) @ op op.amplitude = a op.power_spectrum = a**2 + + try: + from .. import re as jft + from .. 
import RGSpace + + if not all(isinstance(dom, RGSpace) for dom in op.target): + warn(f"unable to add JAX operator for {op.target!r}") + raise ImportError("short-circuit JAX init") + + dists = tuple(e for di in op.target for e in di.distances) + cfm = jft.CorrelatedFieldMaker(prefix=prefix) + cfm.add_fluctuations( + shape=op.target.shape, + distances=dists, + fluctuations=fluctuations, + loglogavgslope=loglogavgslope, + flexibility=flexibility, + asperity=asperity, + prefix="", + harmonic_domain_type="fourier", + non_parametric_kind="power", + ) + cfm.set_amplitude_total_offset( + offset_mean=offset_mean, offset_std=offset_std + ) + cf, _ = cfm.finalize() + + op._jax_expr = cf + op.amplitude._jax_expr = cfm.amplitude + op.power_spectrum._jax_expr = cfm.power_spectrum + except (ImportError, TypeError) as e: + if isinstance(e, TypeError): + warn(f"no JAX operator for this configuration;\n{e}") + return op diff --git a/src/library/special_distributions.py b/src/library/special_distributions.py index ad12218fb366bd35de9381de4e3c75347c27c039..1a0bccd1025960e1054fe2bea547da4b89e83d48 100644 --- a/src/library/special_distributions.py +++ b/src/library/special_distributions.py @@ -78,6 +78,23 @@ class _InterpolationOperator(Operator): self._deriv = self._interpolator.derivative() self._inv_table_func = inv_table_func + try: + from jax import numpy as jnp + + def jax_expr(x): + res = jnp.interp(x, self._xs, self._table) + if inv_table_func is not None: + ve = ( + "can not translate arbitrary inverse" + f" table function {inv_table_func!r}" + ) + raise ValueError(ve) + return res + + self._jax_expr = jax_expr + except ImportError: + self._jax_expr = None + def apply(self, x): self._check_input(x) lin = x.jac is not None @@ -118,7 +135,7 @@ class InverseGammaOperator(Operator): time the domain and the target of the operator. alpha : float The alpha-parameter of the inverse-gamma distribution. - q : float or Field + q : float or :class:`nifty8.field.Field` The q-parameter of the inverse-gamma distribution. mode: float The mode of the inverse-gamma distribution. @@ -155,6 +172,14 @@ class InverseGammaOperator(Operator): op = makeOp(self._q) @ op self._op = op + try: + from ..re.stats_distributions import invgamma_prior + + q_val = self._q.val if isinstance(self._q, Field) else self._q + self._jax_expr = invgamma_prior(float(self._alpha), q_val) + except ImportError: + self._jax_expr = None + def apply(self, x): return self._op(x) diff --git a/src/library/variational_models.py b/src/library/variational_models.py index ac64ce5310bfb215bae9f9cff2b12ed0a08b5c20..82ed4552b4d2a376af37f21c3b67b5fb664241cb 100644 --- a/src/library/variational_models.py +++ b/src/library/variational_models.py @@ -50,7 +50,7 @@ class MeanFieldVI: Parameters ---------- - position : Field + position : :class:`nifty8.field.Field` The initial estimate of the approximate mean parameter. hamiltonian : Energy Hamiltonian of the approximated probability distribution. @@ -62,7 +62,7 @@ class MeanFieldVI: doubles. Mirroring samples stabilizes the KL estimate as extreme sample variation is counterbalanced. Since it improves stability in many cases, it is recommended to set `mirror_samples` to `True`. - initial_sig : positive Field or positive float + initial_sig : positive :class:`nifty8.field.Field` or positive float The initial estimate of the standard deviation. 
comm : MPI communicator or None If not None, samples will be distributed as evenly as possible across @@ -140,7 +140,7 @@ class FullCovarianceVI: Parameters ---------- - position : Field + position : :class:`nifty8.field.Field` The initial estimate of the approximate mean parameter. hamiltonian : Energy Hamiltonian of the approximated probability distribution. diff --git a/src/linearization.py b/src/linearization.py index f1a590c0608c8ffe09c4a2ddd8f9ec7a9350f037..75c4ccf894bf96d48b809430260b59bee922bf28 100644 --- a/src/linearization.py +++ b/src/linearization.py @@ -29,7 +29,7 @@ class Linearization(Operator): Parameters ---------- - val : Field or MultiField + val : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` The value of the operator application. jac : LinearOperator The Jacobian. @@ -52,7 +52,7 @@ class Linearization(Operator): Parameters ---------- - val : Field or MultiField + val : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` the value of the operator application jac : LinearOperator the Jacobian @@ -83,7 +83,7 @@ class Linearization(Operator): @property def val(self): - """Field or MultiField : the value""" + """:class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` : the value""" return self._val @property @@ -93,7 +93,7 @@ class Linearization(Operator): @property def gradient(self): - """Field or MultiField : the gradient + """:class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` : the gradient Notes ----- @@ -198,7 +198,7 @@ class Linearization(Operator): Parameters ---------- - other : Field or MultiField or Linearization + other : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` or Linearization Returns ------- @@ -223,7 +223,7 @@ class Linearization(Operator): Parameters ---------- - other : Field or MultiField or Linearization + other : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` or Linearization Returns ------- @@ -292,7 +292,7 @@ class Linearization(Operator): Parameters ---------- - field : Field or Multifield + field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` the field to be converted want_metric : bool If True, the metric will be computed for other Linearizations @@ -313,7 +313,7 @@ class Linearization(Operator): Parameters ---------- - field : Field or Multifield + field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` the field to be converted want_metric : bool If True, the metric will be computed for other Linearizations @@ -338,7 +338,7 @@ class Linearization(Operator): Parameters ---------- - field : Field or Multifield + field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` the field to be converted want_metric : bool If True, the metric will be computed for other Linearizations @@ -367,7 +367,7 @@ class Linearization(Operator): Parameters ---------- - field : Multifield + field ::class:`nifty8.multi_field.MultiField` the field to be converted constants : list of string the MultiField components for which the Jacobian should be diff --git a/src/minimization/descent_minimizers.py b/src/minimization/descent_minimizers.py index 4c50ed0f644560b14adee3eed37273fc2331bf8f..9280392d2b4aa160b4020dfd445bcbb6834b6701 100644 --- a/src/minimization/descent_minimizers.py +++ b/src/minimization/descent_minimizers.py @@ -125,7 +125,7 @@ class DescentMinimizer(Minimizer): Returns ------- - Field + :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` The descent 
direction. """ raise NotImplementedError @@ -316,9 +316,9 @@ class _InformationStore: ---------- max_history_length : int Maximum number of stored past updates. - x0 : Field + x0 : :class:`nifty8.field.Field` Initial position in variable space. - gradient : Field + gradient : :class:`nifty8.field.Field` Gradient at position x0. Attributes @@ -329,9 +329,9 @@ class _InformationStore: Circular buffer of past position differences, which are Fields. y : List Circular buffer of past gradient differences, which are Fields. - last_x : Field + last_x : :class:`nifty8.field.Field` Latest position in variable space. - last_gradient : Field + last_gradient : :class:`nifty8.field.Field` Gradient at latest position. k : int Number of updates that have taken place diff --git a/src/minimization/energy.py b/src/minimization/energy.py index 5981612310224b05df8dc1a9995c7f85c90e9ef2..1a8c7922403f0b58c2c07bd4c72c1832e975ee65 100644 --- a/src/minimization/energy.py +++ b/src/minimization/energy.py @@ -26,7 +26,7 @@ class Energy(metaclass=NiftyMeta): Parameters ---------- - position : Field + position : :class:`nifty8.field.Field` The input parameter of the scalar function. Notes @@ -51,7 +51,7 @@ class Energy(metaclass=NiftyMeta): Parameters ---------- - position : Field + position : :class:`nifty8.field.Field` Location in parameter space for the new Energy object. Returns @@ -64,7 +64,7 @@ class Energy(metaclass=NiftyMeta): @property def position(self): """ - Field : selected location in parameter space. + field : selected location in parameter space. The Field location in parameter space where value, gradient and metric are evaluated. @@ -83,7 +83,7 @@ class Energy(metaclass=NiftyMeta): @property def gradient(self): """ - Field : The gradient at given `position`. + field : The gradient at given `position`. """ raise NotImplementedError @@ -109,12 +109,12 @@ class Energy(metaclass=NiftyMeta): """ Parameters ---------- - x: Field or MultiField + x : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` Argument for the metric operator Returns ------- - Field or MultiField: + :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` Output of the metric operator """ raise NotImplementedError @@ -124,7 +124,7 @@ class Energy(metaclass=NiftyMeta): Parameters ---------- - direction : Field + direction : :class:`nifty8.field.Field` the search direction Returns diff --git a/src/minimization/energy_adapter.py b/src/minimization/energy_adapter.py index 1f265c3079c8fe515c7cc42de2270a43549cb20b..afb2f15add1ce314ec9ccbb75cf7a0625fad1a46 100644 --- a/src/minimization/energy_adapter.py +++ b/src/minimization/energy_adapter.py @@ -33,16 +33,16 @@ class EnergyAdapter(Energy): Parameters ----------- - position: Field or MultiField + position : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` The position where the minimization process is started. - op: EnergyOperator + op : EnergyOperator The expression computing the energy from the input data. - constants: list of strings + constants : list of strings The component names of the operator's input domain which are assumed to be constant during the minimization process. If the operator's input domain is not a MultiField, this must be empty. Default: []. - want_metric: bool + want_metric : bool If True, the class will provide a `metric` property. This should only be enabled if it is required, because it will most likely consume additional resources. Default: False. 
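
A minimal usage sketch for the parameters documented above; `pos` and `ham`
are placeholders for a position `MultiField` and an `EnergyOperator`, and the
`'xi'` key is hypothetical:

    import nifty8 as ift

    # wrap the energy operator, keeping the 'xi' component of the position
    # fixed and requesting a metric so that NewtonCG can use it
    energy = ift.EnergyAdapter(pos, ham, constants=['xi'], want_metric=True)
    minimizer = ift.NewtonCG(ift.GradientNormController(iteration_limit=10))
    energy, _ = minimizer(energy)
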
@@ -170,7 +170,7 @@ class StochasticEnergyAdapter(Energy): Parameters ---------- - position : MultiField + position : :class:`nifty8.multi_field.MultiField` Values of the optimization parameters op : Operator The objective function of the optimization problem. Must have a diff --git a/src/minimization/kl_energies.py b/src/minimization/kl_energies.py index 4889dba9aebfe71eaa23d736c968e1c7f44a8401..06e4402731cf5bc978dac4f3e80ca3ea436b5cb0 100644 --- a/src/minimization/kl_energies.py +++ b/src/minimization/kl_energies.py @@ -51,7 +51,7 @@ def _reduce_by_keys(field, operator, keys): Parameters ---------- - field : Field or MultiField + field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` Potentially partially constant input field. operator : Operator Operator into which `field` is partially inserted. @@ -183,9 +183,9 @@ def SampledKLEnergy(position, hamiltonian, n_samples, minimizer_sampling, Parameters ---------- - position : Field + position : :class:`nifty8.field.Field` Expansion point of the coordinate transformation. - hamiltonian : StandardHamiltonian + hamiltonian : :class:`nifty8.operators.energy_operators.StandardHamiltonian` Hamiltonian of the approximated probability distribution. n_samples : integer Number of samples used to stochastically estimate the KL. diff --git a/src/minimization/line_search.py b/src/minimization/line_search.py index 127df9f7d7720e31a2ff0de69c756b5c06eaa5c2..f30f753ea4fb1b1c17019d1c7b04271140d9e0d6 100644 --- a/src/minimization/line_search.py +++ b/src/minimization/line_search.py @@ -34,7 +34,7 @@ class LineEnergy: self.energy.position = zero_point + line_position*line_direction energy : Energy The Energy object which will be evaluated along the given direction. - line_direction : Field + line_direction : :class:`nifty8.field.Field` Direction used for line evaluation. Does not have to be normalized. offset : float *optional* Indirectly defines the zero point of the line via the equation @@ -156,7 +156,7 @@ class LineSearch(metaclass=NiftyMeta): energy : Energy Energy object from which we will calculate the energy and the gradient at a specific point. - pk : Field + pk : :class:`nifty8.field.Field` Vector pointing into the search direction. f_k_minus_1 : float, optional Value of the fuction (which is being minimized) at the k-1 diff --git a/src/minimization/optimize_kl.py b/src/minimization/optimize_kl.py index 0947c3516041ab0d1a92a2651330b78ef5983834..4a9446637673c3cc230a55c900301ec16943b019 100644 --- a/src/minimization/optimize_kl.py +++ b/src/minimization/optimize_kl.py @@ -138,14 +138,14 @@ def optimize_kl(likelihood_energy, output_directory : str or None Directory in which all output files are saved. If None, no output is stored. Default: "nifty_optimize_kl_output". - initial_position : Field, MultiField or None + initial_position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` or None Position in the definition space of `likelihood_energy` from which the optimization is started. If `None`, it starts at a random, normal distributed position with standard deviation 0.1. Default: None. initial_index : int Initial index that is used to enumerate the output files. May be used if `optimize_kl` is called multiple times. Default: 0. - ground_truth_position : Field, MultiField or None + ground_truth_position : :class:`nifty8.field.Field`, :class:`nifty8.multi_field.MultiField` or None Position in latent space that represents the ground truth. Used only in plotting. May be useful for validating algorithms. 
comm : MPI communicator or None @@ -195,7 +195,7 @@ def optimize_kl(likelihood_energy, ------- sl : SampleList - mean : Field or MultiField (optional) + mean : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` (optional) Note ---- diff --git a/src/minimization/quadratic_energy.py b/src/minimization/quadratic_energy.py index f921caa30f35a258275f9ecc161c186a08440c14..a63d7ee22960a0e4cc9619d2f9db8c09b281cdb8 100644 --- a/src/minimization/quadratic_energy.py +++ b/src/minimization/quadratic_energy.py @@ -50,9 +50,9 @@ class QuadraticEnergy(Energy): Parameters ---------- - position : Field + position : :class:`nifty8.field.Field` Location in parameter space for the new Energy object. - grad : Field + grad : :class:`nifty8.field.Field` Energy gradient at the new position. Returns diff --git a/src/minimization/sample_list.py b/src/minimization/sample_list.py index 0167586a040d1684cdce3e236c4e30667c5423c9..b709c77a97500d6d5d148bebf68ad8821f7ab23e 100644 --- a/src/minimization/sample_list.py +++ b/src/minimization/sample_list.py @@ -453,9 +453,9 @@ class ResidualSampleList(SampleListBase): Parameters ---------- - mean : Field or MultiField + mean : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` Mean of the sample list. - residuals : list of Field or list of MultiField + residuals : list of :class:`nifty8.field.Field` or list of :class:`nifty8.multi_field.MultiField` List of residuals from the mean. If it is a list of `MultiField`, the domain of the residuals can be a subdomain of the domain of mean. This results in adding just a zero in respective `MultiField` @@ -547,7 +547,7 @@ class SampleList(SampleListBase): Parameters ---------- - samples : list of Field or list of MultiField + samples : list of :class:`nifty8.field.Field` or list of :class:`nifty8.multi_field.MultiField` List of samples. comm : MPI communicator or None If not `None`, samples can be gathered across multiple MPI tasks. If diff --git a/src/multi_field.py b/src/multi_field.py index ed6f05d9da973d1ebc318a8eb89e51aad9f801b0..e87573901efbda77d9c0594d96cab8e1b849781f 100644 --- a/src/multi_field.py +++ b/src/multi_field.py @@ -262,7 +262,7 @@ class MultiField(Operator): Parameters ---------- - other : MultiField + other : :class:`nifty8.multi_field.MultiField` the partner Field Returns @@ -281,7 +281,7 @@ class MultiField(Operator): Parameters ---------- - fields : iterable of MultiFields + fields : iterable of :class:`nifty8.multi_field.MultiField` The set of input fields. Their domains need not be identical. domain : MultiDomain or None If supplied, this will be the domain of the resulting field. @@ -308,7 +308,7 @@ class MultiField(Operator): Parameters ---------- - other : MultiField + other : :class:`nifty8.multi_field.MultiField` the partner Field neg : bool or dict if True, the partner field is subtracted, otherwise added diff --git a/src/nifty2jax.py b/src/nifty2jax.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb3910370b7bf82211c567973b4c3acc9fc1131 --- /dev/null +++ b/src/nifty2jax.py @@ -0,0 +1,149 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial, reduce +import operator +from typing import Any, Callable, Optional, Union +from warnings import warn + +from jax.tree_util import register_pytree_node_class + +from . 
import re as jft +from .domain_tuple import DomainTuple +from .field import Field +from .multi_domain import MultiDomain +from .multi_field import MultiField +from .operators.operator import Operator +from .sugar import makeField + + +def spaces_to_axes(domain, spaces): + """Converts spaces in a domain to axes of the underlying NumPy array.""" + if spaces is None: + return None + + domain = DomainTuple.make(domain) + axes = tuple(domain.axes[sp_index] for sp_index in spaces) + axes = reduce(operator.add, axes) if len(axes) > 0 else axes + return axes + + +def shapewithdtype_from_domain(domain, dtype): + if isinstance(dtype, dict): + dtp_fallback = float # Fallback to `float` for unspecified keys + k2dtp = dtype + else: + dtp_fallback = dtype + k2dtp = {} + + if isinstance(domain, MultiDomain): + parameter_tree = {} + for k, dom in domain.items(): + parameter_tree[k] = jft.ShapeWithDtype( + dom.shape, k2dtp.get(k, dtp_fallback) + ) + elif isinstance(domain, DomainTuple): + parameter_tree = jft.ShapeWithDtype(domain.shape, dtype) + else: + raise TypeError(f"incompatible domain {domain!r}") + return parameter_tree + + +@register_pytree_node_class +class Model(jft.Field): + """Modified field class with an additional call method taking itself as + input. + """ + def __init__(self, apply: Optional[Callable], val, domain=None, flags=None): + """Instantiates a modified field with an accompanying callable. + + Parameters + ---------- + apply : callable + Method acting on `val`. + val : object + Arbitrary, flatten-able objects. + domain : dict or None, optional + Domain of the field, e.g. with description of modes and volume. + flags : set, str or None, optional + Capabilities and constraints of the field. + """ + super().__init__(val, domain, flags) + self._apply = apply + + def tree_flatten(self): + return ((self._val, ), (self._apply, self._domain, self._flags)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + aux_data[0], *children, domain=aux_data[1], flags=aux_data[2] + ) + + def __call__(self, *args, **kwargs): + if self._apply is None: + nie = "no `apply` method specified; behaving like field" + raise NotImplementedError(nie) + return self._apply(*args, **kwargs) + + @property + def field(self): + return jft.Field(self.val, domain=self.domain, flags=self.flags) + + def __str__(self): + s = f"Model(\n{self._apply},\n{self.val}" + if self._domain: + s += f",\ndomain={self._domain}" + if self._flags: + s += f",\nflags={self._flags}" + s += ")" + return s + + def __repr__(self): + s = f"Model(\n{self._apply!r},\n{self.val!r}" + if self._domain: + s += f",\ndomain={self._domain!r}" + if self._flags: + s += f",\nflags={self._flags!r}" + s += ")" + return s + + +def wrap_nifty_call(op, target_dtype=float) -> Callable[[Any], jft.Field]: + from jax.experimental.host_callback import call + + if callable(op.jax_expr): + warn("wrapping operator that has a callable `.jax_expr`") + + def pack_unpack_call(x): + x = makeField(op.domain, x) + return op(x).val + + # TODO: define custom JVP and VJP rules + pt = shapewithdtype_from_domain(op.target, target_dtype) + hcb_call = partial(call, pack_unpack_call, result_shape=pt) + + def wrapped_call(x) -> jft.Field: + return jft.Field(hcb_call(x)) + + return wrapped_call + + +def convert(nifty_obj: Union[Operator,DomainTuple,MultiDomain], dtype=float) -> Model: + if not isinstance(nifty_obj, (Operator, DomainTuple, MultiDomain)): + raise TypeError(f"invalid input type {type(nifty_obj)!r}") + + if isinstance(nifty_obj, (Field, 
MultiField)): + expr = None + parameter_tree = jft.Field(nifty_obj.val) + elif isinstance(nifty_obj, (DomainTuple, MultiDomain)): + expr = None + parameter_tree = shapewithdtype_from_domain(nifty_obj, dtype) + else: + expr = nifty_obj.jax_expr + parameter_tree = shapewithdtype_from_domain(nifty_obj.domain, dtype) + if not callable(expr): + # TODO: implement conversion via host_callback and custom_vjp + raise NotImplementedError("Sorry, not yet done :(") + + return Model(expr, parameter_tree) diff --git a/src/operators/adder.py b/src/operators/adder.py index 6e9f0aece191ca9ead3454879a160476e97883fb..58bee1c735b8e01ee06ee0acfcace9cdc2e055e0 100644 --- a/src/operators/adder.py +++ b/src/operators/adder.py @@ -15,6 +15,8 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. +from operator import add, sub + import numpy as np from ..field import Field @@ -28,7 +30,7 @@ class Adder(Operator): Parameters ---------- - a : Field or MultiField or Scalar + a : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` or Scalar The field by which the input is shifted. """ def __init__(self, a, neg=False, domain=None): @@ -42,6 +44,24 @@ class Adder(Operator): self._domain = self._target = dom self._neg = bool(neg) + try: + from ..re import Field as ReField + from jax.tree_util import tree_map + + a_j = ReField(a.val) if isinstance(a, (Field, MultiField)) else a + + def jax_expr(x): + # Preserve the input type + if not isinstance(x, ReField): + a_astype_x = a_j.val if isinstance(a_j, ReField) else a_j + else: + a_astype_x = a_j + return tree_map(sub if neg else add, x, a_astype_x) + + self._jax_expr = jax_expr + except ImportError: + self._jax_expr = None + def apply(self, x): self._check_input(x) if self._neg: diff --git a/src/operators/chain_operator.py b/src/operators/chain_operator.py index 62e7b8bf6eea18277eb687224f2f701c0a5db8b0..52f06fe2ef86b5814c84f35a776ea92317a0fbcf 100644 --- a/src/operators/chain_operator.py +++ b/src/operators/chain_operator.py @@ -40,6 +40,17 @@ class ChainOperator(LinearOperator): self._domain = self._ops[-1].domain self._target = self._ops[0].target + if all(callable(op.jax_expr) for op in ops): + + def joined_jax_op(x): + for op in reversed(ops): + x = op.jax_expr(x) + return x + + self._jax_expr = joined_jax_op + else: + self._jax_expr = None + @staticmethod def simplify(ops): # verify domains diff --git a/src/operators/contraction_operator.py b/src/operators/contraction_operator.py index 9eb10770752bf32d01444a7be403a5beb8a358c6..2db2343597a6125176d3d3c067ff1f4b92f50f87 100644 --- a/src/operators/contraction_operator.py +++ b/src/operators/contraction_operator.py @@ -15,6 +15,8 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. +from functools import partial + import numpy as np from .. import utilities @@ -51,6 +53,35 @@ class ContractionOperator(LinearOperator): self._power = power self._capability = self.TIMES | self.ADJOINT_TIMES + try: + from jax import numpy as jnp + from jax.tree_util import tree_map + from ..nifty2jax import spaces_to_axes + + fct = jnp.array(1.) + wgt = jnp.array(1.) 
+ if self._power != 0: + for ind in self._spaces: + wgt_spc = self._domain[ind].dvol + if np.isscalar(wgt_spc): + fct *= wgt_spc + else: + new_shape = np.ones(len(self._domain.shape), dtype=np.int64) + new_shape[self._domain.axes[ind][0]: + self._domain.axes[ind][-1]+1] = wgt_spc.shape + wgt *= wgt_spc.reshape(new_shape)**power + fct = fct**power + + def weighted_space_sum(x): + if self._power != 0: + x = fct * wgt * x + axes = spaces_to_axes(self._domain, self._spaces) + return tree_map(partial(jnp.sum, axis=axes), x) + + self._jax_expr = weighted_space_sum + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) if mode == self.ADJOINT_TIMES: diff --git a/src/operators/diagonal_operator.py b/src/operators/diagonal_operator.py index caeefca05c44a746f72a1d22ea23a298f71d499a..22978268397b004a984f96263f5d0cf3b8d102a8 100644 --- a/src/operators/diagonal_operator.py +++ b/src/operators/diagonal_operator.py @@ -16,6 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np +from functools import partial +from operator import mul from .. import utilities from ..domain_tuple import DomainTuple @@ -32,7 +34,7 @@ class DiagonalOperator(EndomorphicOperator): Parameters ---------- - diagonal : Field + diagonal : :class:`nifty8.field.Field` The diagonal entries of the operator. domain : Domain, tuple of Domain or DomainTuple, optional The domain on which the Operator's input Field is defined. @@ -92,6 +94,8 @@ class DiagonalOperator(EndomorphicOperator): self._ldiag = diagonal.val self._fill_rest() + self._jax_expr = partial(mul, self._ldiag) + def _fill_rest(self): self._ldiag.flags.writeable = False self._complex = utilities.iscomplextype(self._ldiag.dtype) @@ -109,6 +113,9 @@ class DiagonalOperator(EndomorphicOperator): res._spaces = tuple(set(self._spaces) | set(spc)) res._ldiag = np.array(ldiag) res._fill_rest() + + res._jax_expr = partial(mul, res._ldiag) + return res def _scale(self, fct): diff --git a/src/operators/einsum.py b/src/operators/einsum.py index 03aa39989fb47ca3175491333690eba594b280be..8486a186635ae8c697635aa6f8ebd28e4c8808b8 100644 --- a/src/operators/einsum.py +++ b/src/operators/einsum.py @@ -174,7 +174,7 @@ class LinearEinsum(LinearOperator): ---------- domain : Domain, DomainTuple or tuple of Domain The operator's input domain. - mf : MultiField + mf : :class:`nifty8.multi_field.MultiField` The first part of the left-hand side of the einsum. subscripts : str The subscripts which is passed to einsum. Everything before the very diff --git a/src/operators/endomorphic_operator.py b/src/operators/endomorphic_operator.py index 6adbfd1bf413b235a817378089731f5519e0c553..0e3fad962e5e540078c1cb58df63864043f7c6f7 100644 --- a/src/operators/endomorphic_operator.py +++ b/src/operators/endomorphic_operator.py @@ -43,7 +43,7 @@ class EndomorphicOperator(LinearOperator): Returns ------- - Field or MultiField + :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` A sample from the Gaussian of given covariance. """ raise NotImplementedError diff --git a/src/operators/energy_operators.py b/src/operators/energy_operators.py index 4d2b3b8c78c77743c37d3362bf6369009eaa99c2..6d9f11c15300955c5be0bf6438a3fb0075a14132 100644 --- a/src/operators/energy_operators.py +++ b/src/operators/energy_operators.py @@ -488,7 +488,7 @@ class GaussianEnergy(LikelihoodEnergyOperator): Parameters ---------- - data : Field or None + data : :class:`nifty8.field.Field` or None Observed data of the Gaussian likelihood. 
If `inverse_covariance` is `None`, the `dtype` of `data` is used for sampling. Default is 0. @@ -597,7 +597,7 @@ class PoissonianEnergy(LikelihoodEnergyOperator): Parameters ---------- - d : Field + d : :class:`nifty8.field.Field` Data field with counts. Needs to have integer dtype and all field values need to be non-negative. """ @@ -635,14 +635,14 @@ class InverseGammaEnergy(LikelihoodEnergyOperator): \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i This is the likelihood for the variance :math:`x=S_k` given data - :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have - the covariance :math:`S_k`. + :math:`\\beta = 0.5 |s_k|^2` where the :class:`nifty8.field.Field` + :math:`s` is known to have the covariance :math:`S_k`. Parameters ---------- - beta : Field + beta : :class:`nifty8.field.Field` beta parameter of the inverse gamma distribution - alpha : Scalar, Field, optional + alpha : Scalar, :class:`nifty8.field.Field`, optional alpha parameter of the inverse gamma distribution """ @@ -694,7 +694,7 @@ class StudentTEnergy(LikelihoodEnergyOperator): ---------- domain : `Domain` or `DomainTuple` Domain of the operator - theta : Scalar or Field + theta : Scalar or :class:`nifty8.field.Field` Degree of freedom parameter for the student t distribution """ @@ -733,7 +733,7 @@ class BernoulliEnergy(LikelihoodEnergyOperator): Parameters ---------- - d : Field + d : :class:`nifty8.field.Field` Data field with events (1) or non-events (0). """ @@ -838,7 +838,7 @@ class AveragedEnergy(EnergyOperator): ---------- h: Hamiltonian The energy to be averaged. - res_samples : iterable of Fields + res_samples : iterable of :class:`nifty8.field.Field` Set of residual sample points to be added to mean field for approximate estimation of the KL. diff --git a/src/operators/harmonic_operators.py b/src/operators/harmonic_operators.py index 39d4b31bead9cdfa6b202665fbee32a29209cd76..1895f8cdf6cc599923cf2286a829eec38892bf7f 100644 --- a/src/operators/harmonic_operators.py +++ b/src/operators/harmonic_operators.py @@ -16,6 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np +from functools import partial from .. import utilities from ..domain_tuple import DomainTuple @@ -71,6 +72,36 @@ class FFTOperator(LinearOperator): adom.check_codomain(target) target.check_codomain(adom) + try: + from jax.numpy import fft as jfft + + axes = self.domain.axes[self._space] + + def jax_expr(x, inverse=False): + if inverse: + if self.domain[self._space].harmonic: + func = jfft.fftn + fct = 1. + else: + func = jfft.ifftn + fct = self.domain[self._space].size + fct *= self.target[self._space].scalar_dvol + else: + if self.domain[self._space].harmonic: + func = jfft.ifftn + fct = self.domain[self._space].size + else: + func = jfft.fftn + fct = 1. 
+ fct *= self.domain[self._space].scalar_dvol + return fct * func(x, axes=axes) if fct != 1 else func(x, axes=axes) + + self._jax_expr = jax_expr + self._jax_expr_inv = partial(jax_expr, inverse=True) + + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) ncells = x.domain[self._space].size @@ -138,6 +169,33 @@ class HartleyOperator(LinearOperator): adom.check_codomain(target) target.check_codomain(adom) + try: + from jax.numpy import fft as jfft + + axes = self.domain.axes[self._space] + + def hartley(a): + ft = jfft.fftn(a, axes=axes) + return ft.real + ft.imag + + def apply_cartesian(x, inverse=False): + if inverse: + fct = self.target[self._space].scalar_dvol + else: + fct = self.domain[self._space].scalar_dvol + return fct * hartley(x) if fct != 1 else hartley(x) + + def jax_expr(x, inverse=False): + ap = partial(apply_cartesian, inverse=inverse) + if np.issubdtype(x.dtype.type, np.complexfloating): + return ap(x.real) + 1j * ap(x.imag) + return ap(x) + + self._jax_expr = jax_expr + self._jax_expr_inv = partial(jax_expr, inverse=True) + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) if utilities.iscomplextype(x.dtype): @@ -314,6 +372,7 @@ class HarmonicTransformOperator(LinearOperator): self._domain = self._op.domain self._target = self._op.target self._capability = self.TIMES | self.ADJOINT_TIMES + self._jax_expr = self._op.jax_expr def apply(self, x, mode): self._check_input(x, mode) diff --git a/src/operators/jax_operator.py b/src/operators/jax_operator.py index a70676cfbe98ff349ca1d2fff4502497d769110d..79275a71fdf638ab0565f76bfe034cba1913a5a6 100644 --- a/src/operators/jax_operator.py +++ b/src/operators/jax_operator.py @@ -18,6 +18,7 @@ from types import SimpleNamespace from warnings import warn import numpy as np +from functools import partial from .energy_operators import LikelihoodEnergyOperator from .linear_operator import LinearOperator @@ -59,17 +60,23 @@ class JaxOperator(Operator): self._domain = makeDomain(domain) self._target = makeDomain(target) self._func = jax.jit(func) - self._vjp = jax.jit(lambda x: jax.vjp(func, x)) + self._bwd = jax.jit(lambda x, y: jax.vjp(func, x)[1](y)[0]) self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1]) + self._jax_expr = func + def apply(self, x): from ..multi_domain import MultiDomain from ..sugar import is_linearization, makeField self._check_input(x) if is_linearization(x): - res, bwd = self._vjp(x.val.val) - fwd = lambda y: self._fwd(x.val.val, y) - jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=lambda x: bwd(x)[0]) + # TODO: Adapt the Linearization class to handle value_and_grad + # calls. Computing the pass through the function thrice (once now + # and twice when differentiating) is redundant and inefficient. 
+ res = self._func(x.val.val) + bwd = partial(self._bwd, x.val.val) + fwd = partial(self._fwd, x.val.val) + jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=bwd) return x.new(makeField(self._target, _jax2np(res)), jac) res = _jax2np(self._func(x.val)) if isinstance(res, dict): @@ -157,6 +164,8 @@ class JaxLinearOperator(LinearOperator): self._func_T = func_T self._capability = self.TIMES | self.ADJOINT_TIMES + self._jax_expr = func + def apply(self, x, mode): from ..sugar import makeField self._check_input(x, mode) diff --git a/src/operators/linear_operator.py b/src/operators/linear_operator.py index c2186c8875195e824382324854057e4df5b050bd..7c6522441dd69f3cbd0a4d57d36a72adf8596d28 100644 --- a/src/operators/linear_operator.py +++ b/src/operators/linear_operator.py @@ -148,7 +148,7 @@ class LinearOperator(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` The input Field, defined on the Operator's domain or target, depending on mode. @@ -161,7 +161,7 @@ class LinearOperator(Operator): Returns ------- - Field + :class:`nifty8.field.Field` The processed Field defined on the Operator's target or domain, depending on mode. """ @@ -180,12 +180,12 @@ class LinearOperator(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` The input Field, defined on the Operator's domain. Returns ------- - Field + :class:`nifty8.field.Field` The processed Field defined on the Operator's target domain. """ return self.apply(x, self.TIMES) @@ -195,12 +195,12 @@ class LinearOperator(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` The input Field, defined on the Operator's target domain Returns ------- - Field + :class:`nifty8.field.Field` The processed Field defined on the Operator's domain. """ return self.apply(x, self.INVERSE_TIMES) @@ -210,12 +210,12 @@ class LinearOperator(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` The input Field, defined on the Operator's target domain Returns ------- - Field + :class:`nifty8.field.Field` The processed Field defined on the Operator's domain. """ return self.apply(x, self.ADJOINT_TIMES) @@ -225,12 +225,12 @@ class LinearOperator(Operator): Parameters ---------- - x : Field + x : :class:`nifty8.field.Field` The input Field, defined on the Operator's domain. Returns ------- - Field + :class:`nifty8.field.Field` The processed Field defined on the Operator's target domain. Notes diff --git a/src/operators/mask_operator.py b/src/operators/mask_operator.py index 11adb3b4efbc60de385331b970bcad8f87faa999..9d2584ffafc22119cf99be30561112a99b840695 100644 --- a/src/operators/mask_operator.py +++ b/src/operators/mask_operator.py @@ -31,7 +31,7 @@ class MaskOperator(LinearOperator): Parameters ---------- - flags : Field + flags : :class:`nifty8.field.Field` Is converted to boolean. Where True, the input field is flagged. 
""" def __init__(self, flags): @@ -42,6 +42,11 @@ class MaskOperator(LinearOperator): self._target = DomainTuple.make(UnstructuredDomain(self._flags.sum())) self._capability = self.TIMES | self.ADJOINT_TIMES + def mask(x): + return x[self._flags] + + self._jax_expr = mask + def apply(self, x, mode): self._check_input(x, mode) x = x.val diff --git a/src/operators/operator.py b/src/operators/operator.py index 60e7fc5a0c706243bc05df54dd4025d0179469f0..722c1946b70bd04e7aa436f8fa7edf67502c259b 100644 --- a/src/operators/operator.py +++ b/src/operators/operator.py @@ -21,6 +21,9 @@ from operator import add import numpy as np +from warnings import warn +from typing import Callable, Optional + from .. import pointwise from ..domain_tuple import DomainTuple from ..logger import logger @@ -112,6 +115,15 @@ class Operator(metaclass=NiftyMeta): """ return None + @property + def jax_expr(self) -> Optional[Callable]: + """Equivalent representation of the operator in JAX.""" + expr = getattr(self, "_jax_expr", None) + # NOTE, it is incredibly useful to enable this for debugging + # if expr is None: + # warn(f"no JAX expression associated with operator {self!r}") + return expr + def scale(self, factor): if not isinstance(factor, numbers.Number): raise TypeError(".scale() takes a number as input") @@ -250,7 +262,7 @@ class Operator(metaclass=NiftyMeta): Parameters ---------- - x : Field or MultiField + x : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` Input on which the operator shall act. Needs to be defined on :attr:`domain`. """ @@ -416,6 +428,32 @@ class _FunctionApplier(Operator): self._args = args self._kwargs = kwargs + try: + import jax.numpy as jnp + from jax import nn as jax_nn + + if funcname in pointwise.ptw_nifty2jax_dict: + jax_expr = pointwise.ptw_nifty2jax_dict[funcname] + elif hasattr(jnp, funcname): + jax_expr = getattr(jnp, funcname) + elif hasattr(jax_nn, funcname): + jax_expr = getattr(jax_nn, funcname) + else: + warn(f"unable to add JAX call for {funcname!r}") + jax_expr = None + + def jax_expr_part(x): # Partial insert with first open argument + return jax_expr(x, *args, **kwargs) + + if isinstance(self.domain, MultiDomain): + from functools import partial + from jax.tree_util import tree_map + + jax_expr_part = partial(tree_map, jax_expr_part) + self._jax_expr = jax_expr_part + except ImportError: + self._jax_expr = None + def apply(self, x): self._check_input(x) return x.ptw(self._funcname, *self._args, **self._kwargs) @@ -425,11 +463,22 @@ class _FunctionApplier(Operator): class _CombinedOperator(Operator): - def __init__(self, ops, _callingfrommake=False): + def __init__(self, ops, jax_ops, _callingfrommake=False): if not _callingfrommake: raise NotImplementedError self._ops = tuple(ops) + if all(callable(jop) for jop in jax_ops): + + def joined_jax_op(x): + for jop in reversed(jax_ops): + x = jop(x) + return x + + self._jax_expr = joined_jax_op + else: + self._jax_expr = None + @classmethod def unpack(cls, ops, res): for op in ops: @@ -444,12 +493,13 @@ class _CombinedOperator(Operator): res = cls.unpack(ops, []) if len(res) == 1: return res[0] - return cls(res, _callingfrommake=True) + jax_res = tuple(op.jax_expr for op in ops) + return cls(res, jax_res, _callingfrommake=True) class _OpChain(_CombinedOperator): - def __init__(self, ops, _callingfrommake=False): - super(_OpChain, self).__init__(ops, _callingfrommake) + def __init__(self, ops, jax_ops, _callingfrommake=False): + super(_OpChain, self).__init__(ops, jax_ops, _callingfrommake) self._domain = 
self._ops[-1].domain self._target = self._ops[0].target for i in range(1, len(self._ops)): @@ -486,6 +536,17 @@ class _OpProd(Operator): self._op1 = op1 self._op2 = op2 + lhs_has_jax = callable(self._op1.jax_expr) + rhs_has_jax = callable(self._op2.jax_expr) + if lhs_has_jax and rhs_has_jax: + + def joined_jax_expr(x): + return self._op1.jax_expr(x) * self._op2.jax_expr(x) + + self._jax_expr = joined_jax_expr + else: + self._jax_expr = None + def apply(self, x): from ..linearization import Linearization from ..sugar import makeOp @@ -529,6 +590,16 @@ class _OpSum(Operator): self._op1 = op1 self._op2 = op2 + try: + from ..re import unite + + def joined_jax_expr(x): + return unite(self._op1.jax_expr(x), self._op2.jax_expr(x)) + + self._jax_expr = joined_jax_expr + except ImportError: + self._jax_expr = None + def apply(self, x): self._check_input(x) return self._apply_operator_sum(x, [self._op1, self._op2]) diff --git a/src/operators/operator_adapter.py b/src/operators/operator_adapter.py index e16a43e3565b0d073ccd1afc9c75b24bb7c02f19..9f8da261e0afa8c0aa29fda0b8481015e95cddbb 100644 --- a/src/operators/operator_adapter.py +++ b/src/operators/operator_adapter.py @@ -38,7 +38,7 @@ class OperatorAdapter(LinearOperator): 3) adjoint inverse """ - def __init__(self, op, op_transform): + def __init__(self, op, op_transform, domain_dtype=float): self._op = op self._trafo = int(op_transform) if self._trafo < 1 or self._trafo > 3: @@ -47,6 +47,35 @@ class OperatorAdapter(LinearOperator): self._target = self._op._tgt(1 << self._trafo) self._capability = self._capTable[self._trafo][self._op.capability] + try: + from jax import eval_shape, linear_transpose + import jax.numpy as jnp + from jax.tree_util import tree_map, tree_all + + from ..nifty2jax import shapewithdtype_from_domain + from ..re import Field + + if callable(op.jax_expr) and self._trafo == self.ADJOINT_BIT: + def jax_expr(y): + op_domain = shapewithdtype_from_domain(op.domain, domain_dtype) + op_domain = Field(op_domain) if isinstance(y, Field) else op_domain + tentative_yshape = eval_shape(op.jax_expr, op_domain) + if not tree_all(tree_map(lambda a,b : jnp.can_cast(a.dtype, b.dtype), y, tentative_yshape)): + raise ValueError(f"wrong dtype during transposition:/got {tentative_yshape} and expected {y!r}") + y = tree_map(lambda c, d: c.astype(d.dtype, casting="safe", copy=False), y, tentative_yshape) + y_conj = tree_map(jnp.conj, y) + jax_expr_T = linear_transpose(op.jax_expr, op_domain) + return tree_map(jnp.conj, jax_expr_T(y_conj)[0]) + + self._jax_expr = jax_expr + elif hasattr(op, "_jax_expr_inv") and callable(op._jax_expr_inv) and self._trafo == self.INVERSE_BIT: + self._jax_expr = op._jax_expr_inv + self._jax_expr_inv = op._jax_expr + else: + self._jax_expr = None + except ImportError: + self._jax_expr = None + def _flip_modes(self, trafo): newtrafo = trafo ^ self._trafo return self._op if newtrafo == 0 \ diff --git a/src/operators/outer_product_operator.py b/src/operators/outer_product_operator.py index 72b8b71f233784e2087de6e1149669d610dd2e91..6dd6a4d4248b2ee9fb8ded42488a120700b4d6aa 100644 --- a/src/operators/outer_product_operator.py +++ b/src/operators/outer_product_operator.py @@ -15,10 +15,12 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. 
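The `OperatorAdapter.jax_expr` above obtains the adjoint from `jax.linear_transpose`, conjugating the input and output so the transpose becomes the adjoint for complex dtypes. A self-contained sketch of the underlying mechanism on a hypothetical real-valued toy map (not from the patch):

```python
from jax import linear_transpose
from jax import numpy as jnp

def lin(x):
    # Toy linear map from R^3 to R^2.
    return jnp.stack((x.sum(), (jnp.arange(3.) * x).sum()))

# `linear_transpose` needs an example input to fix shapes and dtypes.
lin_T = linear_transpose(lin, jnp.zeros(3))
(xt,) = lin_T(jnp.ones(2))  # transposed map applied to a cotangent; shape (3,)
```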
+from functools import partial import numpy as np from ..domain_tuple import DomainTuple from ..field import Field +from ..multi_field import MultiField from .linear_operator import LinearOperator @@ -27,8 +29,8 @@ class OuterProduct(LinearOperator): Parameters --------- - domain: DomainTuple, the domain of the input field - field: Field + domain : DomainTuple, the domain of the input field + field : :class:`nifty8.field.Field` --------- """ def __init__(self, domain, field): @@ -38,6 +40,29 @@ class OuterProduct(LinearOperator): tuple(sub_d for sub_d in field.domain._dom + self._domain._dom)) self._capability = self.TIMES | self.ADJOINT_TIMES + try: + from ..re import Field as ReField + from jax import numpy as jnp + from jax.tree_util import tree_map + + a_j = ReField(field.val) if isinstance(field, (Field, MultiField)) else field + + def jax_expr(x): + # Preserve the input type + if not isinstance(x, ReField): + a_astype_x = a_j.val if isinstance(a_j, ReField) else a_j + else: + a_astype_x = a_j + + return tree_map( + partial(jnp.tensordot, axes=((), ())), + a_astype_x, x + ) + + self._jax_expr = jax_expr + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: diff --git a/src/operators/scaling_operator.py b/src/operators/scaling_operator.py index ab6e79f9a2770d500b2e3e8ba5a6a27b154668b9..6a557027836fc84a38ca08d0b747eb80d20bf2f4 100644 --- a/src/operators/scaling_operator.py +++ b/src/operators/scaling_operator.py @@ -66,6 +66,14 @@ class ScalingOperator(EndomorphicOperator): check_dtype_or_none(sampling_dtype, self._domain) self._dtype = sampling_dtype + try: + from jax import numpy as jnp + from functools import partial + + self._jax_expr = partial(jnp.multiply, factor) + except ImportError: + self._jax_expr = None + def apply(self, x, mode): from ..sugar import full diff --git a/src/operators/simple_linear_operators.py b/src/operators/simple_linear_operators.py index 336f19b1bfb84158b9757405e86b68c0235bc2fa..5e075cb3d7809334aec3a134810934391ca62806 100644 --- a/src/operators/simple_linear_operators.py +++ b/src/operators/simple_linear_operators.py @@ -16,6 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. 
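For reference, `jnp.tensordot` with empty `axes` (as used in `OuterProduct.jax_expr` above) is exactly the outer product; a quick illustration with made-up values:

```python
from jax import numpy as jnp

a = jnp.arange(2.)  # stands in for the field held by the operator
x = jnp.arange(3.)  # stands in for the operator input
res = jnp.tensordot(a, x, axes=((), ()))  # shape (2, 3); equals jnp.outer(a, x)
```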
import numpy as np +from functools import partial from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain @@ -32,7 +33,7 @@ class VdotOperator(LinearOperator): Parameters ---------- - field : Field or MultiField + field : :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` The field used to build the scalar product with the operator input """ def __init__(self, field): @@ -41,6 +42,13 @@ class VdotOperator(LinearOperator): self._target = DomainTuple.scalar_domain() self._capability = self.TIMES | self.ADJOINT_TIMES + try: + from ..re import vdot + + self._jax_expr = partial(vdot, field.val) + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_mode(mode) if mode == self.TIMES: @@ -61,6 +69,14 @@ class ConjugationOperator(EndomorphicOperator): self._domain = DomainTuple.make(domain) self._capability = self._all_ops + try: + from jax import numpy as jnp + from jax.tree_util import tree_map + + self._jax_expr = partial(tree_map, jnp.conjugate) + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) return x.conjugate() @@ -108,6 +124,14 @@ class Realizer(EndomorphicOperator): self._domain = DomainTuple.make(domain) self._capability = self.TIMES | self.ADJOINT_TIMES + try: + from jax import numpy as jnp + from jax.tree_util import tree_map + + self._jax_expr = partial(tree_map, jnp.real) + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) return x.real @@ -126,6 +150,14 @@ class Imaginizer(EndomorphicOperator): self._domain = DomainTuple.make(domain) self._capability = self.TIMES | self.ADJOINT_TIMES + try: + from jax import numpy as jnp + from jax.tree_util import tree_map + + self._jax_expr = partial(tree_map, jnp.imag) + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: @@ -166,6 +198,22 @@ class FieldAdapter(LinearOperator): self._target = MultiDomain.make({name: tmp[name]}) self._capability = self.TIMES | self.ADJOINT_TIMES + try: + from .. import re as jft + + def wrap(x): + return jft.Field({name: x}) + + def unwrap(x): + return x[name] + + if isinstance(tmp, DomainTuple): + self._jax_expr = unwrap + else: + self._jax_expr = wrap + except ImportError: + self._jax_expr = None + def apply(self, x, mode): self._check_input(x, mode) if isinstance(x, MultiField): @@ -310,6 +358,11 @@ class GeometryRemover(LinearOperator): self._target = DomainTuple.make(tgt) self._capability = self.TIMES | self.ADJOINT_TIMES + def identity(x): + return x + + self._jax_expr = identity + def apply(self, x, mode): self._check_input(x, mode) return x.cast_domain(self._tgt(mode)) diff --git a/src/operators/sum_operator.py b/src/operators/sum_operator.py index 3cbbb05106b75112e853b924e861beedd7b3a26e..65a8018ed800269783426baa3df48a64cdb97574 100644 --- a/src/operators/sum_operator.py +++ b/src/operators/sum_operator.py @@ -16,6 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. 
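The pointwise `jax_expr`s above (`ConjugationOperator`, `Realizer`, `Imaginizer`) all rely on the same `partial(tree_map, ...)` pattern, which lifts a single-array `jnp` function to arbitrary pytrees of arrays; a small standalone illustration:

```python
from functools import partial

from jax import numpy as jnp
from jax.tree_util import tree_map

conjugate = partial(tree_map, jnp.conjugate)
x = {"a": jnp.array([1. + 2.j]), "b": jnp.array([3. - 1.j])}
conjugate(x)  # {'a': Array([1.-2.j]), 'b': Array([3.+1.j])}
```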
 from collections import defaultdict
+import operator
 
 from ..sugar import domain_union
 from ..utilities import indent
@@ -42,6 +43,24 @@ class SumOperator(LinearOperator):
         for op in ops:
             self._capability &= op.capability
 
+        try:
+            from ..re import unite
+
+            def joined_jax_expr(x):
+                res = None
+                for op, n in zip(ops, neg):
+                    tmp = op.jax_expr(x)
+                    if res is None:
+                        res = -tmp if n is True else tmp
+                    else:
+                        o = operator.sub if n is True else operator.add
+                        res = unite(res, tmp, op=o)
+                return res
+
+            self._jax_expr = joined_jax_expr
+        except ImportError:
+            self._jax_expr = None
+
     @staticmethod
     def simplify(ops, neg):
         from .diagonal_operator import DiagonalOperator
@@ -173,7 +192,7 @@ class SumOperator(LinearOperator):
             Individual operators of the sum.
         neg: list of bool
             Same length as ops.
-            If True then the equivalent operator gets a minus in the sum.
+            If True then the corresponding operator gets a minus in the sum.
         """
         ops = tuple(ops)
         neg = tuple(neg)
diff --git a/src/plot.py b/src/plot.py
index d47fca6d6a3f5267b797770d36fa366247ad6163..efcba788a55ff146d0a810bfed50a9251b0cf523 100644
--- a/src/plot.py
+++ b/src/plot.py
@@ -575,7 +575,7 @@ class Plot:
 
         Parameters
         ----------
-        f: Field or list of Field or None
+        f : :class:`nifty8.field.Field` or list of :class:`nifty8.field.Field` or None
            If `f` is a single Field, it must be defined on a single
            `RGSpace`, `PowerSpace`, `HPSpace`, `GLSpace`.
            If it is a list, all list members must be Fields defined over the
diff --git a/src/pointwise.py b/src/pointwise.py
index b709d2ba674b5ecccae9fdbb2ebbf5598bc0749b..a15d74bf34f8269da0fab50a4a5c4c53cb234280 100644
--- a/src/pointwise.py
+++ b/src/pointwise.py
@@ -153,3 +153,23 @@ ptw_dict = {
     "arctan": (np.arctan, lambda v: (np.arctan(v), 1./(1.+v**2))),
     "unitstep": (lambda v: _step_helper(v, False), lambda v: _step_helper(v, True))
 }
+
+
+def sigmoid_j(v):
+    from jax import numpy as jnp
+
+    # NOTE: the sigmoid used in NIFTy is different from the one commonly
+    # referred to as sigmoid in most of the literature.
+    return 0.5 + (0.5 * jnp.tanh(v))
+
+
+def exponentiate_j(v, base):
+    from jax import numpy as jnp
+
+    return jnp.power(base, v)
+
+
+ptw_nifty2jax_dict = {
+    "sigmoid": sigmoid_j,
+    "exponentiate": exponentiate_j,
+}
diff --git a/src/probing.py b/src/probing.py
index eae3ef916155708b699ba543ddb019949abcef1a..04068f0c7232c48379ff66787d7ebe007da907b0 100644
--- a/src/probing.py
+++ b/src/probing.py
@@ -87,7 +87,7 @@ def probe_with_posterior_samples(op, post_op, nprobes, dtype):
 
     Returns
     -------
-    List of Field
+    List of :class:`nifty8.field.Field`
         List of two fields: the mean and the variance.
     '''
     if not isinstance(op, EndomorphicOperator):
@@ -129,7 +129,7 @@ def probe_diagonal(op, nprobes, random_type="pm1"):
 
     Returns
     -------
-    Field
+    :class:`nifty8.field.Field`
         The estimated diagonal.
     '''
     sc = StatCalculator()
diff --git a/src/re/README.md b/src/re/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1dfa61822c8f6496f658c84d173acac7f062b578
--- /dev/null
+++ b/src/re/README.md
@@ -0,0 +1,24 @@
+# Re-envisioning NIFTy
+
+## JAX
+
+The (soft-linked) code in this directory is a new interface for NIFTy written in JAX.
+Some features of this new API are straightforward re-implementations of features in NIFTy, while others are orthogonal to NIFTy and follow a different, usually more functional approach.
+All essential pieces of NIFTy are implemented, and the API is capable of (almost) fully replacing NIFTy's current NumPy-based implementation.
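To give a flavor of the interface this README describes, here is a minimal, illustrative sketch built only from the exports listed in `src/re/__init__.py` below; the model and numbers are made up and exact call conventions may differ:

```python
import jax
from jax import numpy as jnp
import nifty8.re as jft

# Toy forward model: a plain function acting on a dict of parameters.
def forward(x):
    return jnp.exp(x["xi"])

data = jnp.ones(16)
lh = jft.Gaussian(data, noise_cov_inv=lambda x: 10. * x) @ forward
ham = jft.StandardHamiltonian(likelihood=lh)

pos = {"xi": jnp.zeros(16)}
energy, grads = jax.value_and_grad(ham)(pos)  # freely jit- and grad-able
```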
+
+### Current Features
+
+* MAP
+* MGVI
+* geoVI
+* Non-parametric correlated field
+
+### TODO
+
+The likelihood (or the Hamiltonian) is probably the object for which translating to a different interface makes the most sense.
+The minimization can differ depending on the API used, but the likelihood should be a common denominator.
+Inference schemes like MGVI, geoVI or MAP do not need to be similar, nor should they be.
+For all of these methods, a more functional approach is desired instead.
+
+Overall, it would make sense to re-implement `optimize_kl` from NIFTy because it abstracts away many details of how MGVI, geoVI or MAP are implemented.
+Furthermore, this would make transitioning from NumPy NIFTy to a JAX-based NIFTy easier while at the same time allowing for many changes to the interfaces of MGVI, geoVI and MAP.
diff --git a/src/re/__init__.py b/src/re/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c170487ff1b48622b6a6b26fa0bc3f7e447975e6
--- /dev/null
+++ b/src/re/__init__.py
@@ -0,0 +1,68 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from . import refine
+from . import refine_util
+from . import refine_chart
+from . import lanczos
+from . import structured_kernel_interpolation
+from .conjugate_gradient import cg, static_cg
+from .correlated_field import CorrelatedFieldMaker, non_parametric_amplitude
+from .energy_operators import (
+    Categorical,
+    Gaussian,
+    Poissonian,
+    StudentT,
+    VariableCovarianceGaussian,
+    VariableCovarianceStudentT,
+)
+from .field import Field
+from .forest_util import (
+    ShapeWithDtype,
+    assert_arithmetics,
+    dot,
+    has_arithmetics,
+    map_forest,
+    map_forest_mean,
+    norm,
+    shape,
+    size,
+    stack,
+    unite,
+    unstack,
+    vdot,
+    zeros_like,
+)
+from .hmc import generate_hmc_acc_rej, generate_nuts_tree
+from .hmc_oo import HMCChain, NUTSChain
+from .kl import (
+    GeoMetricKL,
+    MetricKL,
+    geometrically_sample_standard_hamiltonian,
+    mean_hessp,
+    mean_metric,
+    mean_value_and_grad,
+    sample_standard_hamiltonian,
+)
+from .lanczos import stochastic_lq_logdet
+from .likelihood import Likelihood, StandardHamiltonian
+from .optimize import minimize, newton_cg, trust_ncg
+from .refine_chart import CoordinateChart, RefinementField
+from .stats_distributions import (
+    invgamma_invprior,
+    invgamma_prior,
+    laplace_prior,
+    lognormal_invprior,
+    lognormal_prior,
+    normal_prior,
+    uniform_prior,
+)
+from .sugar import (
+    ducktape,
+    ducktape_left,
+    interpolate,
+    mean,
+    mean_and_std,
+    random_like,
+    sum_of_squares,
+)
diff --git a/src/re/conjugate_gradient.py b/src/re/conjugate_gradient.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f3415a06e360761531bbf49c205f52b94c330f0
--- /dev/null
+++ b/src/re/conjugate_gradient.py
@@ -0,0 +1,650 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+import sys
+from datetime import datetime
+from functools import partial
+from jax import numpy as jnp
+from jax import lax
+
+from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
+
+from .forest_util import assert_arithmetics, common_type, size, where, zeros_like
+from .forest_util import norm as jft_norm
+from .sugar import doc_from, sum_of_squares
+
+HessVP = Callable[[jnp.ndarray], jnp.ndarray]
+
+N_RESET = 20
+
+
+class CGResults(NamedTuple):
+    x: jnp.ndarray
+    nit: Union[int, jnp.ndarray]
+    nfev: Union[int, jnp.ndarray]  # number of matrix-evaluations
+    info: Union[int, jnp.ndarray]
+    success: Union[bool, jnp.ndarray]
+
+
+def cg(mat, j, x0=None, *args, **kwargs) -> Tuple[Any, Union[int, jnp.ndarray]]:
+    """Solve `mat(x) = j` using Conjugate Gradient. `mat` must be callable and
+    represent a Hermitian, positive definite matrix.
+
+    Notes
+    -----
+    If set, the parameters `absdelta` and `resnorm` always take precedence over
+    `tol` and `atol`.
+    """
+    assert_arithmetics(j)
+    if x0 is not None:
+        assert_arithmetics(x0)
+    cg_res = _cg(mat, j, x0, *args, **kwargs)
+    return cg_res.x, cg_res.info
+
+
+@doc_from(cg)
+def static_cg(mat, j, x0=None, *args, **kwargs):
+    assert_arithmetics(j)
+    if x0 is not None:
+        assert_arithmetics(x0)
+    cg_res = _static_cg(mat, j, x0, *args, **kwargs)
+    return cg_res.x, cg_res.info
+
+
+# Taken from nifty
+def _cg(
+    mat,
+    j,
+    x0=None,
+    *,
+    absdelta=None,
+    resnorm=None,
+    norm_ord=None,
+    tol=1e-5,  # taken from SciPy's linalg.cg
+    atol=0.,
+    miniter=None,
+    maxiter=None,
+    name=None,
+    time_threshold=None,
+    _within_newton=False
+) -> CGResults:
+    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
+    maxiter_fallback = 20 * size(j)  # taken from SciPy's NewtonCG minimizer
+    miniter = min(
+        (6, maxiter if maxiter is not None else maxiter_fallback)
+    ) if miniter is None else miniter
+    maxiter = max(
+        (min((200, maxiter_fallback)), miniter)
+    ) if maxiter is None else maxiter
+
+    if absdelta is None and resnorm is None:  # fallback convergence criterion
+        resnorm = jnp.maximum(tol * jft_norm(j, ord=norm_ord, ravel=True), atol)
+
+    common_dtp = common_type(j)
+    eps = 6. * jnp.finfo(common_dtp).eps  # taken from SciPy's NewtonCG minimizer
+    tiny = 6. * jnp.finfo(common_dtp).tiny
+
+    if x0 is None:
+        pos = zeros_like(j)
+        r = -j
+        d = r
+        # energy = .5xT M x - xT j
+        energy = 0.
+        nfev = 0
+    else:
+        pos = x0
+        r = mat(pos) - j
+        d = r
+        energy = float(((r - j) / 2).dot(pos))
+        nfev = 1
+    previous_gamma = float(sum_of_squares(r))
+    if previous_gamma == 0:
+        info = 0
+        return CGResults(x=pos, info=info, nit=0, nfev=nfev, success=True)
+
+    info = -1
+    i = 0
+    for i in range(1, maxiter + 1):
+        q = mat(d)
+        nfev += 1
+
+        curv = float(d.dot(q))
+        if curv == 0.:
+            if _within_newton:
+                info = 0
+                break
+            nm = "CG" if name is None else name
+            raise ValueError(f"{nm}: zero curvature")
+        elif curv < 0.:
+            if _within_newton and i > 1:
+                info = 0
+                break
+            elif _within_newton:
+                pos = previous_gamma / (-curv) * j
+                info = 0
+                break
+            nm = "CG" if name is None else name
+            raise ValueError(f"{nm}: negative curvature")
+        alpha = previous_gamma / curv
+        pos = pos - alpha * d
+        if i % N_RESET == 0:
+            r = mat(pos) - j
+            nfev += 1
+        else:
+            r = r - q * alpha
+        gamma = float(sum_of_squares(r))
+        if time_threshold is not None and datetime.now() > time_threshold:
+            info = i
+            break
+        if gamma >= 0. and gamma <= tiny:
+            nm = "CG" if name is None else name
+            print(f"{nm}: gamma=0, converged!", file=sys.stderr)
+            info = 0
+            break
+        if resnorm is not None:
+            norm = float(jft_norm(r, ord=norm_ord, ravel=True))
+            if name is not None:
+                msg = f"{name}: |∇|:{norm:.6e} 🞋:{resnorm:.6e}"
+                print(msg, file=sys.stderr)
+            if norm < resnorm and i >= miniter:
+                info = 0
+                break
+        if absdelta is not None or name is not None:
+            new_energy = float(((r - j) / 2).dot(pos))
+            energy_diff = energy - new_energy
+            if name is not None:
+                msg = (
+                    f"{name}: Iteration {i} ⛰:{new_energy:+.6e} Δ⛰:{energy_diff:.6e}"
+                    + (f" 🞋:{absdelta:.6e}" if absdelta is not None else "")
+                )
+                print(msg, file=sys.stderr)
+        else:
+            new_energy = energy
+        if absdelta is not None:
+            neg_energy_eps = -eps * jnp.abs(new_energy)
+            if energy_diff < neg_energy_eps:
+                nm = "CG" if name is None else name
+                raise ValueError(f"{nm}: WARNING: energy increased")
+            if neg_energy_eps <= energy_diff < absdelta and i >= miniter:
+                info = 0
+                break
+        energy = new_energy
+        d = d * max(0, gamma / previous_gamma) + r
+        previous_gamma = gamma
+    else:
+        nm = "CG" if name is None else name
+        print(f"{nm}: Iteration Limit Reached", file=sys.stderr)
+        info = i
+    return CGResults(x=pos, info=info, nit=i, nfev=nfev, success=info == 0)
+
+
+def _static_cg(
+    mat,
+    j,
+    x0=None,
+    *,
+    absdelta=None,
+    resnorm=None,
+    norm_ord=None,
+    tol=1e-5,  # taken from SciPy's linalg.cg
+    atol=0.,
+    miniter=None,
+    maxiter=None,
+    name=None,
+    _within_newton=False,  # TODO
+    **kwargs
+) -> CGResults:
+    from jax.lax import cond, while_loop
+
+    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
+    maxiter_fallback = 20 * size(j)  # taken from SciPy's NewtonCG minimizer
+    miniter = jnp.minimum(
+        6, maxiter if maxiter is not None else maxiter_fallback
+    ) if miniter is None else miniter
+    maxiter = jnp.maximum(
+        jnp.minimum(200, maxiter_fallback), miniter
+    ) if maxiter is None else maxiter
+
+    if absdelta is None and resnorm is None:  # fallback convergence criterion
+        resnorm = jnp.maximum(tol * jft_norm(j, ord=norm_ord, ravel=True), atol)
+
+    common_dtp = common_type(j)
+    eps = 6. * jnp.finfo(common_dtp).eps  # taken from SciPy's NewtonCG minimizer
+    tiny = 6. * jnp.finfo(common_dtp).tiny
+
+    def continue_condition(v):
+        return v["info"] < -1
+
+    def cg_single_step(v):
+        info = v["info"]
+        pos, r, d, i = v["pos"], v["r"], v["d"], v["iteration"]
+        previous_gamma, previous_energy = v["gamma"], v["energy"]
+
+        i += 1
+
+        q = mat(d)
+        curv = d.dot(q)
+        # ValueError("zero curvature in conjugate gradient")
+        info = jnp.where(curv == 0., -1, info)
+        alpha = previous_gamma / curv
+        # ValueError("implausible gradient scaling `alpha < 0`")
+        info = jnp.where(alpha < 0., -1, info)
+        pos = pos - alpha * d
+        r = cond(
+            i % N_RESET == 0, lambda x: mat(x["pos"]) - x["j"],
+            lambda x: x["r"] - x["q"] * x["alpha"], {
+                "pos": pos,
+                "j": j,
+                "r": r,
+                "q": q,
+                "alpha": alpha
+            }
+        )
+        gamma = sum_of_squares(r)
+
+        info = jnp.where(
+            (gamma >= 0.)
& (gamma <= tiny) & (info != -1), 0, info + ) + if resnorm is not None: + norm = jft_norm(r, ord=norm_ord, ravel=True) + info = jnp.where( + (norm < resnorm) & (i >= miniter) & (info != -1), 0, info + ) + else: + norm = None + # Do not compute the energy if we do not check `absdelta` + if absdelta is not None or name is not None: + energy = ((r - j) / 2).dot(pos) + energy_diff = previous_energy - energy + else: + energy = previous_energy + energy_diff = None + if absdelta is not None: + neg_energy_eps = -eps * jnp.abs(energy) + # print(f"energy increased", file=sys.stderr) + info = jnp.where(energy_diff < neg_energy_eps, -1, info) + info = jnp.where( + (energy_diff >= neg_energy_eps) & (energy_diff < absdelta) & + (i >= miniter) & (info != -1), 0, info + ) + info = jnp.where((i >= maxiter) & (info != -1), i, info) + + d = d * jnp.maximum(0, gamma / previous_gamma) + r + + if name is not None: + from jax.experimental.host_callback import call + + def pp(arg): + msg = ( + ( + "{name}: |∇|:{norm:.6e} 🞋:{resnorm:.6e}\n" + if arg["resnorm"] is not None else "" + ) + "{name}: Iteration {i} ⛰:{energy:+.6e}" + + " Δ⛰:{energy_diff:.6e}" + ( + " 🞋:{absdelta:.6e}" + if arg["absdelta"] is not None else "" + ) + ( + "\n{name}: Iteration Limit Reached" + if arg["i"] == arg["maxiter"] else "" + ) + ) + print(msg.format(name=name, **arg), file=sys.stderr) + + printable_state = { + "i": i, + "energy": energy, + "energy_diff": energy_diff, + "absdelta": absdelta, + "norm": norm, + "resnorm": resnorm, + "maxiter": maxiter + } + call(pp, printable_state, result_shape=None) + + ret = { + "info": info, + "pos": pos, + "r": r, + "d": d, + "iteration": i, + "gamma": gamma, + "energy": energy + } + return ret + + if x0 is None: + pos = zeros_like(j) + r = -j + d = r + nfev = 0 + else: + pos = x0 + r = mat(pos) - j + d = r + nfev = 1 + energy = None + if absdelta is not None or name is not None: + if x0 is None: + # energy = .5xT M x - xT j + energy = jnp.array(0.) 
+ else: + energy = ((r - j) / 2).dot(pos) + + gamma = sum_of_squares(r) + val = { + "info": jnp.array(-2, dtype=int), + "pos": pos, + "r": r, + "d": d, + "iteration": jnp.array(0), + "gamma": gamma, + "energy": energy + } + # Finish early if already converged in the initial iteration + val["info"] = jnp.where(gamma == 0., 0, val["info"]) + + val = while_loop(continue_condition, cg_single_step, val) + + i = val["iteration"] + info = val["info"] + nfev += i + i // N_RESET + return CGResults( + x=val["pos"], info=info, nit=i, nfev=nfev, success=info == 0 + ) + + +# The following is code adapted from Nicholas Mancuso to work with pytrees +class _QuadSubproblemResult(NamedTuple): + step: jnp.ndarray + hits_boundary: Union[bool, jnp.ndarray] + pred_f: Union[float, jnp.ndarray] + nit: Union[int, jnp.ndarray] + nfev: Union[int, jnp.ndarray] + njev: Union[int, jnp.ndarray] + nhev: Union[int, jnp.ndarray] + success: Union[bool, jnp.ndarray] + + +class _CGSteihaugState(NamedTuple): + z: jnp.ndarray + r: jnp.ndarray + d: jnp.ndarray + step: jnp.ndarray + energy: Union[None, float, jnp.ndarray] + hits_boundary: Union[bool, jnp.ndarray] + done: Union[bool, jnp.ndarray] + nit: Union[int, jnp.ndarray] + nhev: Union[int, jnp.ndarray] + + +def second_order_approx( + p: jnp.ndarray, + cur_val: Union[float, jnp.ndarray], + g: jnp.ndarray, + hessp_at_xk: HessVP, +) -> Union[float, jnp.ndarray]: + return cur_val + g.dot(p) + 0.5 * p.dot(hessp_at_xk(p)) + + +def get_boundaries_intersections( + z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray] +): # Adapted from SciPy + """Solve the scalar quadratic equation ||z + t d|| == trust_radius. + + This is like a line-sphere intersection. + + Return the two values of t, sorted from low to high. + """ + a = d.dot(d) + b = 2 * z.dot(d) + c = z.dot(z) - trust_radius**2 + sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c) + + # The following calculation is mathematically + # equivalent to: + # ta = (-b - sqrt_discriminant) / (2*a) + # tb = (-b + sqrt_discriminant) / (2*a) + # but produce smaller round off errors. + # Look at Matrix Computation p.97 + # for a better justification. + aux = b + jnp.copysign(sqrt_discriminant, b) + ta = -aux / (2 * a) + tb = -2 * c / aux + + ra, rb = where(ta < tb, (ta, tb), (tb, ta)) + return (ra, rb) + + +def _cg_steihaug_subproblem( + cur_val: Union[float, jnp.ndarray], + g: jnp.ndarray, + hessp_at_xk: HessVP, + *, + trust_radius: Union[float, jnp.ndarray], + tr_norm_ord: Union[None, int, float, jnp.ndarray] = None, + resnorm: Optional[float], + absdelta: Optional[float] = None, + norm_ord: Union[None, int, float, jnp.ndarray] = None, + miniter: Union[None, int] = None, + maxiter: Union[None, int] = None, + name=None +) -> _QuadSubproblemResult: + """ + Solve the subproblem using a conjugate gradient method. + + Parameters + ---------- + cur_val : Union[float, jnp.ndarray] + Objective value evaluated at the current state. + g : jnp.ndarray + Gradient value evaluated at the current state. + hessp_at_xk: Callable + Function that accepts a proposal vector and computes the result of a + Hessian-vector product. + trust_radius : float + Upper bound on how large a step proposal can be. + tr_norm_ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional + Order of the norm for computing the length of the next step. + norm_ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional + Order of the norm for testing convergence. 
+
+    Returns
+    -------
+    result : _QuadSubproblemResult
+        Contains the step proposal, whether it is at radius boundary, and
+        meta-data regarding function calls and successful convergence.
+
+    Notes
+    -----
+    This is algorithm (7.2) of Nocedal and Wright 2nd edition.
+    Only the function that computes the Hessian-vector product is required.
+    The Hessian itself is not required, and the Hessian does
+    not need to be positive semidefinite.
+    """
+    tr_norm_ord = jnp.inf if tr_norm_ord is None else tr_norm_ord  # taken from JAX
+    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
+    maxiter_fallback = 20 * size(g)  # taken from SciPy's NewtonCG minimizer
+    miniter = jnp.minimum(
+        6, maxiter if maxiter is not None else maxiter_fallback
+    ) if miniter is None else miniter
+    maxiter = jnp.maximum(
+        jnp.minimum(200, maxiter_fallback), miniter
+    ) if maxiter is None else maxiter
+
+    common_dtp = common_type(g)
+    eps = 6. * jnp.finfo(
+        common_dtp
+    ).eps  # Inspired by SciPy's NewtonCG minimizer
+
+    # second-order Taylor series approximation at the current values, gradient,
+    # and Hessian
+    soa = partial(
+        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
+    )
+
+    # helpers for internal switches in the main CGSteihaug logic
+    def noop(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        return iterp
+
+    def step1(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        z, d, nhev = iterp.z, iterp.d, iterp.nhev
+
+        ta, tb = get_boundaries_intersections(z, d, trust_radius)
+        pa = z + ta * d
+        pb = z + tb * d
+        p_boundary = where(soa(pa) < soa(pb), pa, pb)
+        return iterp._replace(
+            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
+        )
+
+    def step2(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        z, d = iterp.z, iterp.d
+
+        ta, tb = get_boundaries_intersections(z, d, trust_radius)
+        p_boundary = z + tb * d
+        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)
+
+    def step3(
+        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
+    ) -> _CGSteihaugState:
+        iterp, z_next = param
+        return iterp._replace(step=z_next, hits_boundary=False, done=True)
+
+    # initialize the step
+    p_origin = zeros_like(g)
+
+    # init the state for the first iteration
+    z = p_origin
+    r = g
+    d = -r
+    energy = 0. if absdelta is not None or name is not None else None
+    init_param = _CGSteihaugState(
+        z=z,
+        r=r,
+        d=d,
+        step=p_origin,
+        energy=energy,
+        hits_boundary=False,
+        done=maxiter == 0,
+        nit=0,
+        nhev=0
+    )
+
+    # Search for the min of the approximation of the objective function.
+ def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState: + z, r, d = iterp.z, iterp.r, iterp.d + energy, nit = iterp.energy, iterp.nit + + nit += 1 + + Bd = hessp_at_xk(d) + dBd = d.dot(Bd) + + r_squared = r.dot(r) + alpha = r_squared / dBd + z_next = z + alpha * d + + r_next = r + alpha * Bd + r_next_squared = r_next.dot(r_next) + + beta_next = r_next_squared / r_squared + d_next = -r_next + beta_next * d + + accept_z_next = nit >= maxiter + if norm_ord == 2: + r_next_norm = jnp.sqrt(r_next_squared) + else: + r_next_norm = jft_norm(r_next, ord=norm_ord, ravel=True) + accept_z_next |= r_next_norm < resnorm + if absdelta is not None or name is not None: + # Relative to a plain CG, `z_next` is negative + energy_next = ((r_next + g) / 2).dot(z_next) + energy_diff = energy - energy_next + else: + energy_next = energy + energy_diff = jnp.nan + if absdelta is not None: + neg_energy_eps = -eps * jnp.abs(energy) + accept_z_next |= (energy_diff >= neg_energy_eps + ) & (energy_diff < absdelta) & (nit >= miniter) + + # include a junk switch to catch the case where none should be executed + z_next_norm = jft_norm(z_next, ord=tr_norm_ord, ravel=True) + index = jnp.argmax( + jnp.array( + [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next] + ) + ) + iterp = lax.switch(index, [noop, step1, step2, step3], (iterp, z_next)) + + iterp = iterp._replace( + z=z_next, + r=r_next, + d=d_next, + energy=energy_next, + nhev=iterp.nhev + 1, + nit=nit + ) + if name is not None: + from jax.experimental.host_callback import call + + def pp(arg): + msg = ( + "{name}: |∇|:{r_norm:.6e} 🞋:{resnorm:.6e} ↗:{tr:.6e}" + " ☞:{case:1d} #∇²:{nhev:02d}" + "\n{name}: Iteration {i} ⛰:{energy:+.6e} Δ⛰:{energy_diff:.6e}" + + ( + " 🞋:{absdelta:.6e}" + if arg["absdelta"] is not None else "" + ) + ( + "\n{name}: Iteration Limit Reached" + if arg["i"] == arg["maxiter"] else "" + ) + ) + print(msg.format(name=name, **arg), file=sys.stderr) + + printable_state = { + "i": nit, + "energy": iterp.energy, + "energy_diff": energy_diff, + "absdelta": absdelta, + "tr": trust_radius, + "r_norm": r_next_norm, + "resnorm": resnorm, + "nhev": iterp.nhev, + "case": index, + "maxiter": maxiter + } + call(pp, printable_state, result_shape=None) + + return iterp + + def cond_f(iterp: _CGSteihaugState) -> bool: + return jnp.logical_not(iterp.done) + + # perform inner optimization to solve the constrained + # quadratic subproblem using cg + result = lax.while_loop(cond_f, body_f, init_param) + + pred_f = soa(result.step) + result = _QuadSubproblemResult( + step=result.step, + hits_boundary=result.hits_boundary, + pred_f=pred_f, + nit=result.nit, + nfev=0, + njev=0, + nhev=result.nhev + 1, + success=True + ) + + return result diff --git a/src/re/correlated_field.py b/src/re/correlated_field.py new file mode 100644 index 0000000000000000000000000000000000000000..4f6da97379ffd3a20a2260a4aa57e4d25b3f59ce --- /dev/null +++ b/src/re/correlated_field.py @@ -0,0 +1,511 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from collections.abc import Mapping +from functools import partial +import sys +from typing import Callable, Dict, Optional, Tuple, Union + +from jax import numpy as jnp +import numpy as np + +from .forest_util import ShapeWithDtype +from .stats_distributions import lognormal_prior, normal_prior +from .sugar import ducktape + + +def _safe_assert(condition): + if not condition: + raise AssertionError() + + +def hartley(p, axes=None): + from jax.numpy import fft + + tmp = fft.fftn(p, 
+    return tmp.real + tmp.imag
+
+
+def get_fourier_mode_distributor(
+    shape: Union[tuple, int], distances: Union[tuple, float]
+):
+    """Get the unique lengths of the Fourier modes, a mapping from a mode to
+    its length index and the multiplicity of each unique Fourier mode length.
+
+    Parameters
+    ----------
+    shape : tuple of int or int
+        Position-space shape.
+    distances : tuple of float or float
+        Position-space distances.
+
+    Returns
+    -------
+    mode_length_idx : jnp.ndarray
+        Index in power-space for every mode in harmonic-space. Can be used to
+        distribute power from a power-space to the full harmonic domain.
+    unique_mode_length : jnp.ndarray
+        Unique length of Fourier modes.
+    mode_multiplicity : jnp.ndarray
+        Multiplicity for each unique Fourier mode length.
+    """
+    shape = (shape, ) if isinstance(shape, int) else tuple(shape)
+
+    # Compute the length of all modes
+    mspc_distances = 1. / (jnp.array(shape) * jnp.array(distances))
+    m_length = jnp.arange(shape[0], dtype=jnp.float64)
+    m_length = jnp.minimum(m_length, shape[0] - m_length) * mspc_distances[0]
+    if len(shape) != 1:
+        m_length *= m_length
+        for i in range(1, len(shape)):
+            tmp = jnp.arange(shape[i], dtype=jnp.float64)
+            tmp = jnp.minimum(tmp, shape[i] - tmp) * mspc_distances[i]
+            tmp *= tmp
+            m_length = jnp.expand_dims(m_length, axis=-1) + tmp
+        m_length = jnp.sqrt(m_length)
+
+    # Construct an array of unique mode lengths
+    uniqueness_rtol = 1e-12
+    um = jnp.unique(m_length)
+    tol = uniqueness_rtol * um[-1]
+    um = um[jnp.diff(jnp.append(um, 2 * um[-1])) > tol]
+    # Group modes based on their length and store the result as the power
+    # distributor
+    binbounds = 0.5 * (um[:-1] + um[1:])
+    m_length_idx = jnp.searchsorted(binbounds, m_length)
+    m_count = jnp.bincount(m_length_idx.ravel(), minlength=um.size)
+    if jnp.any(m_count == 0) or um.shape != m_count.shape:
+        raise RuntimeError("invalid harmonic mode(s) encountered")
+
+    return m_length_idx, um, m_count
+
+
+def _twolog_integrate(log_vol, x):
+    # Map the space to the one for the relative log-modes, i.e. pad the space
+    # of the log volume
+    twolog = jnp.empty((2 + log_vol.shape[0], ))
+    twolog = twolog.at[0].set(0.)
+    twolog = twolog.at[1].set(0.)
+
+    twolog = twolog.at[2:].set(jnp.cumsum(x[1], axis=0))
+    twolog = twolog.at[2:].set(
+        (twolog[2:] + twolog[1:-1]) / 2. * log_vol + x[0]
+    )
+    twolog = twolog.at[2:].set(jnp.cumsum(twolog[2:], axis=0))
+    return twolog
+
+
+def _remove_slope(rel_log_mode_dist, x):
+    sc = rel_log_mode_dist / rel_log_mode_dist[-1]
+    return x - x[-1] * sc
+
+
+def non_parametric_amplitude(
+    domain: Mapping,
+    fluctuations: Callable,
+    loglogavgslope: Callable,
+    flexibility: Optional[Callable] = None,
+    asperity: Optional[Callable] = None,
+    prefix: str = "",
+    kind: str = "amplitude",
+) -> Tuple[Callable, Dict[str, ShapeWithDtype]]:
+    """Constructs a function computing the amplitude of a non-parametric
+    power spectrum.
+
+    See
+    :class:`nifty8.re.correlated_field.CorrelatedFieldMaker.add_fluctuations`
+    for more details on the parameters.
+
+    See also
+    --------
+    `Variable structures in M87* from space, time and frequency resolved
+    interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp
+    and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and
+    Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_
+    `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_
+    """
+    totvol = domain.get("position_space_total_volume", 1.)
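+    # NOTE, the `domain` mapping is expected to carry the mode geometry as
+    # precomputed in `CorrelatedFieldMaker.add_fluctuations`, namely the keys
+    # "relative_log_mode_lengths", "mode_multiplicity" and "log_volume".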
+ rel_log_mode_len = domain["relative_log_mode_lengths"] + mode_multiplicity = domain["mode_multiplicity"] + log_vol = domain.get("log_volume") + + ptree = {} + fluctuations = ducktape(fluctuations, prefix + "fluctuations") + ptree[prefix + "fluctuations"] = ShapeWithDtype(()) + loglogavgslope = ducktape(loglogavgslope, prefix + "loglogavgslope") + ptree[prefix + "loglogavgslope"] = ShapeWithDtype(()) + if flexibility is not None: + flexibility = ducktape(flexibility, prefix + "flexibility") + ptree[prefix + "flexibility"] = ShapeWithDtype(()) + # Register the parameters for the spectrum + _safe_assert(log_vol is not None) + _safe_assert(rel_log_mode_len.ndim == log_vol.ndim == 1) + ptree[prefix + "spectrum"] = ShapeWithDtype((2, ) + log_vol.shape) + if asperity is not None: + asperity = ducktape(asperity, prefix + "asperity") + ptree[prefix + "asperity"] = ShapeWithDtype(()) + + def correlate(primals: Mapping) -> jnp.ndarray: + flu = fluctuations(primals) + slope = loglogavgslope(primals) + slope *= rel_log_mode_len + ln_spectrum = slope + + if flexibility is not None: + _safe_assert(log_vol is not None) + xi_spc = primals[prefix + "spectrum"] + flx = flexibility(primals) + sig_flx = flx * jnp.sqrt(log_vol) + sig_flx = jnp.broadcast_to(sig_flx, (2, ) + log_vol.shape) + + if asperity is None: + shift = jnp.stack( + (log_vol / jnp.sqrt(12.), jnp.ones_like(log_vol)), axis=0 + ) + asp = shift * sig_flx * xi_spc + else: + asp = asperity(primals) + shift = jnp.stack( + (log_vol**2 / 12., jnp.ones_like(log_vol)), axis=0 + ) + sig_asp = jnp.broadcast_to( + jnp.array([[asp], [0.]]), shift.shape + ) + asp = jnp.sqrt(shift + sig_asp) * sig_flx * xi_spc + + twolog = _twolog_integrate(log_vol, asp) + wo_slope = _remove_slope(rel_log_mode_len, twolog) + ln_spectrum += wo_slope + + # Exponentiate and norm the power spectrum + spectrum = jnp.exp(ln_spectrum) + # Take the sqrt of the integral of the slope w/o fluctuations and + # zero-mode while taking into account the multiplicity of each mode + if kind.lower() == "amplitude": + norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:]**2)) + norm /= jnp.sqrt(totvol) # Due to integral in harmonic space + amplitude = flu * (jnp.sqrt(totvol) / norm) * spectrum + elif kind.lower() == "power": + norm = jnp.sqrt(jnp.sum(mode_multiplicity[1:] * spectrum[1:])) + norm /= jnp.sqrt(totvol) # Due to integral in harmonic space + amplitude = flu * (jnp.sqrt(totvol) / norm) * jnp.sqrt(spectrum) + else: + raise ValueError(f"invalid kind specified {kind!r}") + amplitude = amplitude.at[0].set(totvol) + return amplitude + + return correlate, ptree + + +class CorrelatedFieldMaker(): + """Construction helper for hierarchical correlated field models. + + The correlated field models are parametrized by creating square roots of + power spectrum operators ("amplitudes") via calls to + :func:`add_fluctuations*` that act on the targeted field subdomains. + During creation of the :class:`CorrelatedFieldMaker`, a global offset from + zero of the field model can be defined and an operator applying + fluctuations around this offset is parametrized. + + Creation of the model operator is completed by calling the method + :func:`finalize`, which returns the configured operator. + + See the methods initialization, :func:`add_fluctuations` and + :func:`finalize` for further usage information.""" + def __init__(self, prefix: str): + """Instantiate a CorrelatedFieldMaker object. 
+
+        Parameters
+        ----------
+        prefix : string
+            Prefix to the names of the domains of the cf operator to be made.
+            This determines the names of the operator domain.
+        """
+        self._azm = None
+        self._offset_mean = None
+        self._fluctuations = []
+        self._target_subdomains = []
+        self._parameter_tree = {}
+
+        self._prefix = prefix
+
+    def add_fluctuations(
+        self,
+        shape: Union[tuple, int],
+        distances: Union[tuple, float],
+        fluctuations: Union[tuple, Callable],
+        loglogavgslope: Union[tuple, Callable],
+        flexibility: Union[tuple, Callable, None] = None,
+        asperity: Union[tuple, Callable, None] = None,
+        prefix: str = "",
+        harmonic_domain_type: str = "fourier",
+        non_parametric_kind: str = "amplitude",
+    ):
+        """Adds a correlation structure to the to-be-made field.
+
+        Correlations are described by their power spectrum and the subdomain
+        on which they apply.
+
+        Multiple calls to `add_fluctuations` are possible, in which case
+        the constructed field will have the outer product of the individual
+        power spectra as its global power spectrum.
+
+        The parameters `fluctuations`, `flexibility`, `asperity` and
+        `loglogavgslope` configure either the amplitude or the power
+        spectrum model used on the target field subdomain of type
+        `harmonic_domain_type`. It is assembled as the sum of a power
+        law component (linear slope in log-log
+        amplitude-frequency-space), a smoothly varying component
+        (integrated Wiener process) and a ragged component
+        (un-integrated Wiener process).
+
+        Parameters
+        ----------
+        shape : tuple of int
+            Shape of the position space domain.
+        distances : tuple of float or float
+            Distances in the position space domain.
+        fluctuations : tuple of float (mean, std) or callable
+            Total spectral energy, i.e. amplitude of the fluctuations
+            (by default a priori log-normal distributed).
+        loglogavgslope : tuple of float (mean, std) or callable
+            Power law component exponent
+            (by default a priori normal distributed).
+        flexibility : tuple of float (mean, std) or callable or None
+            Amplitude of the non-power-law power spectrum component
+            (by default a priori log-normal distributed).
+        asperity : tuple of float (mean, std) or callable or None
+            Roughness of the non-power-law power spectrum component; use it
+            to accommodate a single frequency peak
+            (by default a priori log-normal distributed).
+        prefix : str
+            Prefix of the power spectrum parameter domain names.
+        harmonic_domain_type : str
+            Description of the harmonic partner domain in which the amplitude
+            lives.
+
+        See also
+        --------
+        `Variable structures in M87* from space, time and frequency resolved
+        interferometry`, Arras, Philipp and Frank, Philipp and Haim, Philipp
+        and Knollmüller, Jakob and Leike, Reimar and Reinecke, Martin and
+        Enßlin, Torsten, `<https://arxiv.org/abs/2002.05218>`_
+        `<http://dx.doi.org/10.1038/s41550-021-01548-0>`_
+        """
+        shape = (shape, ) if isinstance(shape, int) else tuple(shape)
+        distances = tuple(np.broadcast_to(distances, jnp.shape(shape)))
+        totvol = jnp.prod(jnp.array(shape) * jnp.array(distances))
+
+        # Pre-compute lengths of modes and indices for distributing power
+        # TODO: cache results such that only references are used afterwards
+        domain = {
+            "position_space_shape": shape,
+            "position_space_total_volume": totvol,
+            "position_space_distances": distances,
+            "harmonic_domain_type": harmonic_domain_type.lower()
+        }
+        if harmonic_domain_type.lower() == "fourier":
+            domain["harmonic_space_shape"] = shape
+            m_length_idx, um, m_count = get_fourier_mode_distributor(
+                shape, distances
+            )
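+            # `m_length_idx` maps every harmonic mode to the index of its
+            # unique length, `um` holds the unique mode lengths and `m_count`
+            # their multiplicities (see `get_fourier_mode_distributor`).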
domain["power_distributor"] = m_length_idx + domain["mode_multiplicity"] = m_count + + # Transform the unique modes to log-space for the amplitude model + um = um.at[1:].set(jnp.log(um[1:])) + um = um.at[1:].add(-um[1]) + _safe_assert(um[0] == 0.) + domain["relative_log_mode_lengths"] = um + log_vol = um[2:] - um[1:-1] + _safe_assert(um.shape[0] - 2 == log_vol.shape[0]) + domain["log_volume"] = log_vol + else: + ve = f"invalid `harmonic_domain_type` {harmonic_domain_type!r}" + raise ValueError(ve) + + flu = fluctuations + if isinstance(flu, (tuple, list)): + flu = lognormal_prior(*flu) + elif not callable(flu): + te = f"invalid `fluctuations` specified; got '{type(fluctuations)}'" + raise TypeError(te) + slp = loglogavgslope + if isinstance(slp, (tuple, list)): + slp = normal_prior(*slp) + elif not callable(slp): + te = f"invalid `loglogavgslope` specified; got '{type(loglogavgslope)}'" + raise TypeError(te) + + flx = flexibility + if isinstance(flx, (tuple, list)): + flx = lognormal_prior(*flx) + elif flx is not None and not callable(flx): + te = f"invalid `flexibility` specified; got '{type(flexibility)}'" + raise TypeError(te) + asp = asperity + if isinstance(asp, (tuple, list)): + asp = lognormal_prior(*asp) + elif asp is not None and not callable(asp): + te = f"invalid `asperity` specified; got '{type(asperity)}'" + raise TypeError(te) + + npa, ptree = non_parametric_amplitude( + domain=domain, + fluctuations=flu, + loglogavgslope=slp, + flexibility=flx, + asperity=asp, + prefix=self._prefix + prefix, + kind=non_parametric_kind, + ) + self._fluctuations.append(npa) + self._target_subdomains.append(domain) + self._parameter_tree.update(ptree) + + def set_amplitude_total_offset( + self, offset_mean: float, offset_std: Union[tuple, Callable] + ): + """Sets the zero-mode for the combined amplitude operator + + Parameters + ---------- + offset_mean : float + Mean offset from zero of the correlated field to be made. + offset_std : tuple of float or callable + Mean standard deviation and standard deviation of the standard + deviation of the offset. No, this is not a word duplication. + (By default a priori log-normal distributed) + """ + if self._offset_mean is not None and self._azm is not None: + msg = "Overwriting the previous mean offset and zero-mode" + print(msg, file=sys.stderr) + + self._offset_mean = offset_mean + zm = offset_std + if not callable(zm): + if zm is None or len(zm) != 2: + raise TypeError(f"`offset_std` of invalid type {type(zm)!r}") + zm = lognormal_prior(*zm) + + self._azm = ducktape(zm, self._prefix + "zeromode") + self._parameter_tree[self._prefix + "zeromode"] = ShapeWithDtype(()) + + @property + def amplitude_total_offset(self) -> Callable: + """Returns the total offset of the amplitudes""" + if self._azm is None: + nie = "You need to set the `amplitude_total_offset` first" + raise NotImplementedError(nie) + return self._azm + + @property + def azm(self): + """Alias for `amplitude_total_offset`""" + return self.amplitude_total_offset + + @property + def fluctuations(self) -> Tuple[Callable, ...]: + """Returns the added fluctuations, i.e. un-normalized amplitudes + + Their scales are only meaningful relative to one another. Their + absolute scale bares no information. + """ + return tuple(self._fluctuations) + + def get_normalized_amplitudes(self) -> Tuple[Callable, ...]: + """Returns the normalized amplitude operators used in the final model + + The amplitude operators are corrected for the otherwise degenerate + zero-mode. 
+        Their scales are only meaningful relative to one another. Their
+        absolute scale bears no information.
+        """
+        def _mk_normed_amp(amp):  # Avoid late binding
+            def normed_amplitude(p):
+                return amp(p).at[1:].mul(1. / self.azm(p))
+
+            return normed_amplitude
+
+        return tuple(_mk_normed_amp(amp) for amp in self._fluctuations)
+
+    @property
+    def amplitude(self) -> Callable:
+        """Returns the added fluctuation, i.e. un-normalized amplitude"""
+        if len(self._fluctuations) > 1:
+            s = (
+                'If more than one spectrum is present in the model,'
+                ' no unique set of amplitudes exist because only the'
+                ' relative scale is determined.'
+            )
+            raise NotImplementedError(s)
+        amp = self._fluctuations[0]
+
+        def amplitude_w_zm(p):
+            return amp(p).at[0].mul(self.azm(p))
+
+        return amplitude_w_zm
+
+    @property
+    def power_spectrum(self) -> Callable:
+        """Returns the power spectrum"""
+        amp = self.amplitude
+
+        def power(p):
+            return amp(p)**2
+
+        return power
+
+    def finalize(self) -> Tuple[Callable, Dict[str, ShapeWithDtype]]:
+        """Finishes off the model construction process and returns the
+        constructed operator.
+        """
+        harmonic_transforms = []
+        excitation_shape = ()
+        for sub_dom in self._target_subdomains:
+            sub_shp = sub_dom["harmonic_space_shape"]
+            excitation_shape += sub_shp
+            n = len(excitation_shape)
+            axes = tuple(range(n - len(sub_shp), n))
+
+            # TODO: Generalize to complex
+            harmonic_dvol = 1. / sub_dom["position_space_total_volume"]
+            harmonic_transforms.append(
+                (harmonic_dvol, partial(hartley, axes=axes))
+            )
+        # Register the parameters for the excitations in harmonic space
+        # TODO: actually account for the dtype here
+        pfx = self._prefix + "xi"
+        self._parameter_tree[pfx] = ShapeWithDtype(excitation_shape)
+
+        def outer_harmonic_transform(p):
+            harmonic_dvol, ht = harmonic_transforms[0]
+            outer = harmonic_dvol * ht(p)
+            for harmonic_dvol, ht in harmonic_transforms[1:]:
+                outer = harmonic_dvol * ht(outer)
+            return outer
+
+        def _mk_expanded_amp(amp, sub_dom):  # Avoid late binding
+            def expanded_amp(p):
+                return amp(p)[sub_dom["power_distributor"]]
+
+            return expanded_amp
+
+        expanded_amplitudes = []
+        namps = self.get_normalized_amplitudes()
+        for amp, sub_dom in zip(namps, self._target_subdomains):
+            expanded_amplitudes.append(_mk_expanded_amp(amp, sub_dom))
+
+        def outer_amplitude(p):
+            outer = expanded_amplitudes[0](p)
+            for amp in expanded_amplitudes[1:]:
+                # NOTE, the order is important here and must match with the
+                # excitations
+                # TODO, use functions instead and utilize numpy's casting
+                outer = jnp.tensordot(outer, amp(p), axes=0)
+            return outer
+
+        def correlated_field(p):
+            ea = outer_amplitude(p)
+            cf_h = self.azm(p) * ea * p[self._prefix + "xi"]
+            return self._offset_mean + outer_harmonic_transform(cf_h)
+
+        return correlated_field, self._parameter_tree.copy()
diff --git a/src/re/disable_jax_control_flow.py b/src/re/disable_jax_control_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f88d92d8dd2bcc7f3134d8a71a8e65743e2e3b7
--- /dev/null
+++ b/src/re/disable_jax_control_flow.py
@@ -0,0 +1,36 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from jax import lax
+
+_DISABLE_CONTROL_FLOW_PRIM = False
+
+
+def cond(pred, true_fun, false_fun, operand):
+    if _DISABLE_CONTROL_FLOW_PRIM:
+        if pred:
+            return true_fun(operand)
+        else:
+            return false_fun(operand)
+    else:
+        return lax.cond(pred, true_fun, false_fun, operand)
+
+
+def while_loop(cond_fun, body_fun, init_val):
+    if _DISABLE_CONTROL_FLOW_PRIM:
+        val = init_val
+        while cond_fun(val):
+            val = body_fun(val)
+        return val
+    else:
+        return lax.while_loop(cond_fun, body_fun, init_val)
+
+
+def fori_loop(lower, upper, body_fun, init_val):
+    if _DISABLE_CONTROL_FLOW_PRIM:
+        val = init_val
+        for i in range(int(lower), int(upper)):
+            val = body_fun(i, val)
+        return val
+    else:
+        return lax.fori_loop(lower, upper, body_fun, init_val)
diff --git a/src/re/energy_operators.py b/src/re/energy_operators.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d3eef606baf85aff624e234ef4f3a2c48e7ee2
--- /dev/null
+++ b/src/re/energy_operators.py
@@ -0,0 +1,385 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+import sys
+from typing import Callable, Optional, Tuple
+
+from jax import numpy as jnp
+from jax.tree_util import tree_map
+
+from .forest_util import ShapeWithDtype
+from .likelihood import Likelihood
+
+
+def standard_t(nwr, dof):
+    return jnp.sum(jnp.log1p(nwr**2 / dof) * (dof + 1)) / 2
+
+
+def _shape_w_fixed_dtype(dtype):
+    def shp_w_dtp(e):
+        return ShapeWithDtype(jnp.shape(e), dtype)
+
+    return shp_w_dtp
+
+
+def _get_cov_inv_and_std_inv(
+    cov_inv: Optional[Callable],
+    std_inv: Optional[Callable],
+    primals=None
+) -> Tuple[Callable, Callable]:
+    if not cov_inv and not std_inv:
+
+        def cov_inv(tangents):
+            return tangents
+
+        def std_inv(tangents):
+            return tangents
+
+    elif not cov_inv:
+        wm = (
+            "assuming a diagonal covariance matrix"
+            ";\nsetting `cov_inv` to `std_inv(jnp.ones_like(data))**2`"
+        )
+        print(wm, file=sys.stderr)
+        noise_std_inv_sq = std_inv(tree_map(jnp.ones_like, primals))**2
+
+        def cov_inv(tangents):
+            return noise_std_inv_sq * tangents
+
+    elif not std_inv:
+        wm = (
+            "assuming a diagonal covariance matrix"
+            ";\nsetting `std_inv` to `cov_inv(jnp.ones_like(data))**0.5`"
+        )
+        print(wm, file=sys.stderr)
+        noise_cov_inv_sqrt = tree_map(
+            jnp.sqrt, cov_inv(tree_map(jnp.ones_like, primals))
+        )
+
+        def std_inv(tangents):
+            return noise_cov_inv_sqrt * tangents
+
+    if not (callable(cov_inv) and callable(std_inv)):
+        raise ValueError("received un-callable input")
+    return cov_inv, std_inv
+
+
+def Gaussian(
+    data,
+    noise_cov_inv: Optional[Callable] = None,
+    noise_std_inv: Optional[Callable] = None
+):
+    """Gaussian likelihood of the data
+
+    Parameters
+    ----------
+    data : tree-like structure of jnp.ndarray and float
+        Data with additive noise following a Gaussian distribution.
+    noise_cov_inv : callable acting on type of data
+        Function applying the inverse noise covariance of the Gaussian.
+    noise_std_inv : callable acting on type of data
+        Function applying the square root of the inverse noise covariance.
+
+    Notes
+    -----
+    If `noise_std_inv` is `None`, it is inferred by assuming a diagonal noise
+    covariance, i.e. by applying the inverse noise covariance to a vector of
+    ones and taking the square root of the result. If both `noise_cov_inv`
+    and `noise_std_inv` are `None`, a unit covariance is assumed.
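+
+    Examples
+    --------
+    A minimal sketch, assuming a homoscedastic noise variance of ``0.25``
+    (supplied via its inverse) and that the returned likelihood is used as a
+    callable on the model output:
+
+    >>> from jax import numpy as jnp
+    >>> data = jnp.zeros(3)
+    >>> lh = Gaussian(data, noise_cov_inv=lambda x: x / 0.25)
+    >>> energy = lh(jnp.ones(3))  # negative log-likelihood up to constants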
+ """ + noise_cov_inv, noise_std_inv = _get_cov_inv_and_std_inv( + noise_cov_inv, noise_std_inv, data + ) + + def hamiltonian(primals): + p_res = primals - data + return 0.5 * p_res.ravel().dot(noise_cov_inv(p_res).ravel()) + + def metric(primals, tangents): + return noise_cov_inv(tangents) + + def left_sqrt_metric(primals, tangents): + return noise_std_inv(tangents) + + def transformation(primals): + return noise_std_inv(primals) + + lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, data) + + return Likelihood( + hamiltonian, + transformation=transformation, + left_sqrt_metric=left_sqrt_metric, + metric=metric, + lsm_tangents_shape=lsm_tangents_shape + ) + + +def StudentT( + data, + dof, + noise_cov_inv: Optional[Callable] = None, + noise_std_inv: Optional[Callable] = None +): + """Student's t likelihood of the data + + Parameters + ---------- + data : tree-like structure of jnp.ndarray and float + Data with additive noise following a Gaussian distribution. + dof : tree-like structure of jnp.ndarray and float + Degree-of-freedom parameter of Student's t distribution. + noise_cov_inv : callable acting on type of data + Function applying the inverse noise covariance of the Gaussian. + noise_std_inv : callable acting on type of data + Function applying the square root of the inverse noise covariance. + + Notes + ----- + If `noise_std_inv` is `None` it is inferred by assuming a diagonal noise + covariance, i.e. by applying it to a vector of ones and taking the square + root. If both `noise_cov_inv` and `noise_std_inv` are `None`, a unit + covariance is assumed. + """ + noise_cov_inv, noise_std_inv = _get_cov_inv_and_std_inv( + noise_cov_inv, noise_std_inv, data + ) + + def hamiltonian(primals): + """ + primals : mean + """ + return standard_t(noise_std_inv(data - primals), dof) + + def metric(primals, tangents): + """ + primals, tangent : mean + """ + return noise_cov_inv((dof + 1) / (dof + 3) * tangents) + + def left_sqrt_metric(primals, tangents): + """ + primals, tangents : mean + """ + return noise_std_inv(jnp.sqrt((dof + 1) / (dof + 3)) * tangents) + + def transformation(primals): + """ + primals : mean + """ + return noise_std_inv(jnp.sqrt((dof + 1) / (dof + 3)) * primals) + + lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, data) + + return Likelihood( + hamiltonian, + transformation=transformation, + left_sqrt_metric=left_sqrt_metric, + metric=metric, + lsm_tangents_shape=lsm_tangents_shape + ) + + +def Poissonian(data, sampling_dtype=float): + """Computes the negative log-likelihood, i.e. the Hamiltonians of an + expected count field constrained by Poissonian count data. + + Represents up to an f-independent term :math:`log(d!)`: + + .. math :: + E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f), + + where f is a field in data space of the expectation values for the counts. + + Parameters + ---------- + data : ndarray of uint + Data field with counts. Needs to have integer dtype and all values need + to be non-negative. + sampling_dtype : dtype, optional + Data-type for sampling. 
+ """ + from .forest_util import common_type + + dtp = common_type(data) + if not jnp.issubdtype(dtp, jnp.integer): + raise TypeError("`data` of invalid type") + if jnp.any(data < 0): + raise ValueError("`data` may not be negative") + + def hamiltonian(primals): + return jnp.sum(primals) - jnp.vdot(jnp.log(primals), data) + + def metric(primals, tangents): + return tangents / primals + + def left_sqrt_metric(primals, tangents): + return tangents / jnp.sqrt(primals) + + def transformation(primals): + return jnp.sqrt(primals) * 2. + + lsm_tangents_shape = tree_map(_shape_w_fixed_dtype(sampling_dtype), data) + + return Likelihood( + hamiltonian, + transformation=transformation, + left_sqrt_metric=left_sqrt_metric, + metric=metric, + lsm_tangents_shape=lsm_tangents_shape + ) + + +def VariableCovarianceGaussian(data): + """Gaussian likelihood of the data with a variable covariance + + Parameters + ---------- + data : tree-like structure of jnp.ndarray and float + Data with additive noise following a Gaussian distribution. + + Notes + ----- + The likelihood acts on a tuple of (mean, std_inv). + """ + from .sugar import sum_of_squares + + def hamiltonian(primals): + """ + primals : pair of (mean, std_inv) + """ + res = (primals[0] - data) * primals[1] + return 0.5 * sum_of_squares(res) - jnp.sum(jnp.log(primals[1])) + + def metric(primals, tangents): + """ + primals, tangent : pair of (mean, std_inv) + """ + prim_std_inv_sq = primals[1]**2 + res = (prim_std_inv_sq * tangents[0], 2 * tangents[1] / prim_std_inv_sq) + return type(primals)(res) + + def left_sqrt_metric(primals, tangents): + """ + primals, tangent : pair of (mean, std_inv) + """ + res = (primals[1] * tangents[0], jnp.sqrt(2) * tangents[1] / primals[1]) + return type(primals)(res) + + def transformation(primals): + """ + pirmals : pair of (mean, std_inv) + + Notes + ----- + A global transformation to Euclidean space does not exist. A local + approximation invoking the residual is used instead. + """ + # TODO: test by drawing synthetic data that actually follows the + # noise-cov and then average over it + res = (primals[1] * (primals[0] - data), tree_map(jnp.log, primals[1])) + return type(primals)(res) + + lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, (data, data)) + + return Likelihood( + hamiltonian, + transformation=transformation, + left_sqrt_metric=left_sqrt_metric, + metric=metric, + lsm_tangents_shape=lsm_tangents_shape + ) + + +def VariableCovarianceStudentT(data, dof): + """Student's t likelihood of the data with a variable covariance + + Parameters + ---------- + data : tree-like structure of jnp.ndarray and float + Data with additive noise following a Gaussian distribution. + dof : tree-like structure of jnp.ndarray and float + Degree-of-freedom parameter of Student's t distribution. + + Notes + ----- + The likelihood acts on a tuple of (mean, std). 
+ """ + def hamiltonian(primals): + """ + primals : pair of (mean, std) + """ + t = standard_t((data - primals[0]) / primals[1], dof) + t += jnp.sum(jnp.log(primals[1])) + return t + + def metric(primals, tangent): + """ + primals, tangent : pair of (mean, std) + """ + return ( + tangent[0] * (dof + 1) / (dof + 3) / primals[1]**2, + tangent[1] * 2 * dof / (dof + 3) / primals[1]**2 + ) + + def left_sqrt_metric(primals, tangents): + """ + primals, tangents : pair of (mean, std) + """ + cov = ( + (dof + 1) / (dof + 3) / primals[1]**2, + 2 * dof / (dof + 3) / primals[1]**2 + ) + res = (jnp.sqrt(cov[0]) * tangents[0], jnp.sqrt(cov[1]) * tangents[1]) + return res + + lsm_tangents_shape = tree_map(ShapeWithDtype.from_leave, (data, data)) + + return Likelihood( + hamiltonian, + left_sqrt_metric=left_sqrt_metric, + metric=metric, + lsm_tangents_shape=lsm_tangents_shape + ) + + +def Categorical(data, axis=-1, sampling_dtype=float): + """Categorical likelihood of the data, equivalent to cross entropy + + Parameters + ---------- + data : sequence of int + An array stating which of the categories is the realized in the data. + Must agree with the input shape except for the shape[axis] + axis : int + Axis over which the categories are formed + sampling_dtype : dtype, optional + Data-type for sampling. + """ + def hamiltonian(primals): + from jax.nn import log_softmax + logits = log_softmax(primals, axis=axis) + return -jnp.sum(jnp.take_along_axis(logits, data, axis)) + + def metric(primals, tangents): + from jax.nn import softmax + + preds = softmax(primals, axis=axis) + norm_term = jnp.sum(preds * tangents, axis=axis, keepdims=True) + return preds * tangents - preds * norm_term + + def left_sqrt_metric(primals, tangents): + from jax.nn import softmax + + sqrtp = jnp.sqrt(softmax(primals, axis=axis)) + norm_term = jnp.sum(sqrtp * tangents, axis=axis, keepdims=True) + return sqrtp * (tangents - sqrtp * norm_term) + + lsm_tangents_shape = tree_map(_shape_w_fixed_dtype(sampling_dtype), data) + + return Likelihood( + hamiltonian, + left_sqrt_metric=left_sqrt_metric, + metric=metric, + lsm_tangents_shape=lsm_tangents_shape + ) diff --git a/src/re/field.py b/src/re/field.py new file mode 100644 index 0000000000000000000000000000000000000000..eae10b7d08dcf6a9e5848186cb2ebdf2fedf9a3a --- /dev/null +++ b/src/re/field.py @@ -0,0 +1,274 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +import operator +from jax import numpy as jnp +from jax.tree_util import ( + register_pytree_node_class, tree_leaves, tree_map, tree_structure +) + + +def _value_op(op, name=None): + def value_call(lhs, *args, **kwargs): + return op(lhs.val, *args, **kwargs) + + name = op.__name__ if name is None else name + value_call.__name__ = f"__{name}__" + return value_call + + +def _unary_op(op, name=None): + def unary_call(lhs): + return tree_map(op, lhs) + + name = op.__name__ if name is None else name + unary_call.__name__ = f"__{name}__" + return unary_call + + +def _enforce_flags(lhs, rhs): + flags = lhs.flags if isinstance(lhs, Field) else set() + flags |= rhs.flags if isinstance(rhs, Field) else set() + if "strict_domain_checking" in flags: + ts_lhs = tree_structure(lhs) + ts_rhs = tree_structure(rhs) + + if not hasattr(rhs, "domain"): + te = f"RHS of type {type(rhs)} does not have a `domain` property" + raise TypeError(te) + if not hasattr(lhs, "domain"): + te = f"LHS of type {type(lhs)} does not have a `domain` property" + raise TypeError(te) + if rhs.domain != lhs.domain or 
+            raise ValueError("domains and/or structures are incompatible")
+    return flags
+
+
+def _broadcast_binary_op(op, lhs, rhs):
+    from itertools import repeat
+
+    flags = _enforce_flags(lhs, rhs)
+
+    ts_lhs = tree_structure(lhs)
+    ts_rhs = tree_structure(rhs)
+    # Catch non-object scalars and 0d array-likes with a `ndim` property
+    if jnp.isscalar(lhs) or getattr(lhs, "ndim", -1) == 0:
+        lhs = ts_rhs.unflatten(repeat(lhs, ts_rhs.num_leaves))
+    elif jnp.isscalar(rhs) or getattr(rhs, "ndim", -1) == 0:
+        rhs = ts_lhs.unflatten(repeat(rhs, ts_lhs.num_leaves))
+    elif ts_lhs.num_nodes != ts_rhs.num_nodes:
+        ve = f"invalid binary operation {op} for {ts_lhs!r} and {ts_rhs!r}"
+        raise ValueError(ve)
+
+    out = tree_map(op, lhs, rhs)
+    if flags != set():
+        out._flags = flags
+    return out
+
+
+def _binary_op(op, name=None):
+    def binary_call(lhs, rhs):
+        return _broadcast_binary_op(op, lhs, rhs)
+
+    name = op.__name__ if name is None else name
+    binary_call.__name__ = f"__{name}__"
+    return binary_call
+
+
+def _rev_binary_op(op, name=None):
+    def binary_call(lhs, rhs):
+        return _broadcast_binary_op(op, rhs, lhs)
+
+    name = op.__name__ if name is None else name
+    binary_call.__name__ = f"__r{name}__"
+    return binary_call
+
+
+def _fwd_rev_binary_op(op, name=None):
+    return (_binary_op(op, name=name), _rev_binary_op(op, name=name))
+
+
+def matmul(lhs, rhs):
+    """Returns the dot product of the two fields.
+
+    Parameters
+    ----------
+    lhs : object
+        Arbitrary, flatten-able object.
+    rhs : object
+        Arbitrary, flatten-able object.
+
+    Returns
+    -------
+    out : float
+        Dot product of the fields.
+    """
+    from .forest_util import dot
+
+    _enforce_flags(lhs, rhs)
+
+    ts_lhs = tree_structure(lhs)
+    ts_rhs = tree_structure(rhs)
+    if ts_lhs.num_nodes != ts_rhs.num_nodes:
+        ve = f"invalid operation for {ts_lhs!r} and {ts_rhs!r}"
+        raise ValueError(ve)
+
+    return dot(lhs, rhs)
+
+
+dot = matmul
+
+
+@register_pytree_node_class
+class Field():
+    """Value storage for arbitrary objects with added numerics."""
+    supported_flags = {"strict_domain_checking"}
+
+    def __init__(self, val, domain=None, flags=None):
+        """Instantiates a field.
+
+        Parameters
+        ----------
+        val : object
+            Arbitrary, flatten-able object.
+        domain : dict or None, optional
+            Domain of the field, e.g. with description of modes and volume.
+        flags : set, str or None, optional
+            Capabilities and constraints of the field.
+        """
+        self._val = val
+        self._domain = {} if domain is None else dict(domain)
+
+        flags = (flags, ) if isinstance(flags, str) else flags
+        flags = set() if flags is None else set(flags)
+        if not flags.issubset(Field.supported_flags):
+            ve = (
+                f"specified flags ({flags!r}) are not a subset of the"
+                f" supported flags ({Field.supported_flags!r})"
+            )
+            raise ValueError(ve)
+        self._flags = flags
+
+    def tree_flatten(self):
+        """Recipe for flattening fields.
+
+        Returns
+        -------
+        flat_tree : tuple of two tuples
+            Pair of an iterable with the children to be flattened recursively,
+            and some opaque auxiliary data.
+        """
+        return ((self._val, ), (self._domain, self._flags))
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        """Recipe to construct fields from flattened pytrees.
+
+        Parameters
+        ----------
+        aux_data : tuple of a dict and a set
+            Opaque auxiliary data describing a field.
+        children : tuple
+            Value of the field, i.e. unflattened children.
+
+        Returns
+        -------
+        unflattened_tree : :class:`nifty8.re.field.Field`
+            Re-constructed field.
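+
+        Examples
+        --------
+        A minimal sketch of the flatten/unflatten round trip:
+
+        >>> fld = Field({"a": 1.})
+        >>> children, aux_data = fld.tree_flatten()
+        >>> fld2 = Field.tree_unflatten(aux_data, children)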
+ """ + return cls(*children, domain=aux_data[0], flags=aux_data[1]) + + @property + def val(self): + """Retrieves a **view** of the field's values.""" + return self._val + + @property + def domain(self): + """Retrieves a **copy** of the field's domain.""" + return self._domain.copy() + + @property + def flags(self): + """Retrieves a **copy** of the field's flags.""" + return self._flags.copy() + + @property + def size(self): + from .forest_util import size + + return size(self) + + def __str__(self): + s = f"Field(\n{self.val}" + if self._domain: + s += f",\ndomain={self._domain}" + if self._flags: + s += f",\nflags={self._flags}" + s += ")" + return s + + def __repr__(self): + s = f"Field(\n{self.val!r}" + if self._domain: + s += f",\ndomain={self._domain!r}" + if self._flags: + s += f",\nflags={self._flags!r}" + s += ")" + return s + + def ravel(self): + return tree_map(jnp.ravel, self) + + def __bool__(self): + return bool(self.val) + + def __hash__(self): + return hash(tuple(tree_leaves(self))) + + # NOTE, this partly redundant code could be abstracted away using + # `setattr`. However, static code analyzers will not be able to infer the + # properties then. + + __add__, __radd__ = _fwd_rev_binary_op(operator.add) + __sub__, __rsub__ = _fwd_rev_binary_op(operator.sub) + __mul__, __rmul__ = _fwd_rev_binary_op(operator.mul) + __truediv__, __rtruediv__ = _fwd_rev_binary_op(operator.truediv) + __floordiv__, __rfloordiv__ = _fwd_rev_binary_op(operator.floordiv) + __pow__, __rpow__ = _fwd_rev_binary_op(operator.pow) + __mod__, __rmod__ = _fwd_rev_binary_op(operator.mod) + __matmul__ = __rmatmul__ = matmul # arguments of matmul commute + + def __divmod__(self, other): + return self // other, self % other + + def __rdivmod__(self, other): + return other // self, other % self + + __or__, __ror__ = _fwd_rev_binary_op(operator.or_, "or") + __xor__, __rxor__ = _fwd_rev_binary_op(operator.xor) + __and__, __rand__ = _fwd_rev_binary_op(operator.and_, "and") + __lshift__, __rlshift__ = _fwd_rev_binary_op(operator.lshift) + __rshift__, __rrshift__ = _fwd_rev_binary_op(operator.rshift) + + __lt__ = _binary_op(operator.lt) + __le__ = _binary_op(operator.le) + __eq__ = _binary_op(operator.eq) + __ne__ = _binary_op(operator.ne) + __ge__ = _binary_op(operator.ge) + __gt__ = _binary_op(operator.gt) + + __neg__ = _unary_op(operator.neg) + __pos__ = _unary_op(operator.pos) + __abs__ = _unary_op(operator.abs) + __invert__ = _unary_op(operator.invert) + + conj = conjugate = _unary_op(jnp.conj) + real = _unary_op(jnp.real) + imag = _unary_op(jnp.imag) + dot = matmul + + __getitem__ = _value_op(operator.getitem) + __contains__ = _value_op(operator.contains) + __len__ = _value_op(len) + __iter__ = _value_op(iter) diff --git a/src/re/forest_util.py b/src/re/forest_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5a493fb5a711e1126b1d6341d25c551166fb973c --- /dev/null +++ b/src/re/forest_util.py @@ -0,0 +1,403 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial, reduce +import operator +from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union + +from jax import lax +from jax import numpy as jnp +from jax.tree_util import ( + all_leaves, + tree_leaves, + tree_map, + tree_reduce, + tree_structure, + tree_transpose, +) +import numpy as np + +from .field import Field +from .sugar import is1d + + +def split(mappable, keys): + """Split a dictionary into one containing only the specified keys and one 
+    with all of the remaining ones.
+    """
+    sel, rest = {}, {}
+    for k, v in mappable.items():
+        if k in keys:
+            sel[k] = v
+        else:
+            rest[k] = v
+    return sel, rest
+
+
+def unite(x, y, op=operator.add):
+    """Unites two array-, dict- or Field-like objects.
+
+    If a key is contained in both objects, then the fields at that key
+    are combined.
+    """
+    if isinstance(x, Field) or isinstance(y, Field):
+        x = x.val if isinstance(x, Field) else x
+        y = y.val if isinstance(y, Field) else y
+        return Field(unite(x, y, op=op))
+    if not hasattr(x, "keys") and not hasattr(y, "keys"):
+        return op(x, y)
+    if not hasattr(x, "keys") or not hasattr(y, "keys"):
+        te = (
+            "one of the inputs does not have a `keys` property;"
+            f" got {type(x)} and {type(y)}"
+        )
+        raise TypeError(te)
+
+    out = {}
+    for k in x.keys() | y.keys():
+        if k in x and k in y:
+            out[k] = op(x[k], y[k])
+        elif k in x:
+            out[k] = x[k]
+        else:
+            out[k] = y[k]
+    return out
+
+
+CORE_ARITHMETIC_ATTRIBUTES = (
+    "__neg__", "__pos__", "__abs__", "__add__", "__radd__", "__sub__",
+    "__rsub__", "__mul__", "__rmul__", "__truediv__", "__rtruediv__",
+    "__floordiv__", "__rfloordiv__", "__pow__", "__rpow__", "__mod__",
+    "__rmod__", "__matmul__", "__rmatmul__"
+)
+
+
+def has_arithmetics(obj, additional_attributes=()):
+    desired_attrs = CORE_ARITHMETIC_ATTRIBUTES + additional_attributes
+    return all(hasattr(obj, attr) for attr in desired_attrs)
+
+
+def assert_arithmetics(obj, *args, **kwargs):
+    if not has_arithmetics(obj, *args, **kwargs):
+        ae = (
+            f"input of type {type(obj)} does not support"
+            " core arithmetic operations"
+            "\nmaybe you forgot to wrap your object in a"
+            " :class:`nifty8.re.field.Field` instance"
+        )
+        raise AssertionError(ae)
+
+
+class ShapeWithDtype():
+    """Minimal helper class storing the shape and dtype of an object.
+
+    Notes
+    -----
+    This class is not transparent to JAX, i.e. it shall not be flattened
+    itself. If used in a tree-like structure, it should only be used as a
+    leaf.
+    """
+    def __init__(self, shape: Union[Tuple[int], List[int], int], dtype=None):
+        """Instantiates a storage unit for shape and dtype.
+
+        Parameters
+        ----------
+        shape : tuple or list of int
+            One-dimensional sequence of integers denoting the length of the
+            object along each of its axes.
+        dtype : dtype
+            Data-type of the to-be-described object.
+        """
+        if isinstance(shape, int):
+            shape = (shape, )
+        if isinstance(shape, list):
+            shape = tuple(shape)
+        if not is1d(shape):
+            ve = f"invalid shape; got {shape!r}"
+            raise TypeError(ve)
+
+        self._shape = shape
+        self._dtype = jnp.float64 if dtype is None else dtype
+        self._size = None
+
+    @classmethod
+    def from_leave(cls, element):
+        """Convenience method for creating an instance of `ShapeWithDtype`
+        from an object.
+
+        To map a whole tree-like structure to its shapes and dtypes, use
+        JAX's `tree_map` method like so:
+
+            tree_map(ShapeWithDtype.from_leave, tree)
+
+        Parameters
+        ----------
+        element : tree-like structure
+            Object from which to take the shape and data-type.
+
+        Returns
+        -------
+        swd : instance of ShapeWithDtype
+            Instance storing the shape and data-type of `element`.
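+
+        Examples
+        --------
+        A minimal sketch:
+
+        >>> from jax import numpy as jnp
+        >>> swd = ShapeWithDtype.from_leave(jnp.zeros((3, 2)))
+        >>> assert swd.shape == (3, 2) and swd.ndim == 2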
+ """ + if not all_leaves((element, )): + ve = "tree is not flat and still contains leaves" + raise ValueError(ve) + return cls(jnp.shape(element), get_dtype(element)) + + @property + def shape(self) -> Tuple[int]: + """Retrieves the shape.""" + return self._shape + + @property + def dtype(self): + """Retrieves the data-type.""" + return self._dtype + + @property + def size(self) -> int: + """Total number of elements.""" + if self._size is None: + self._size = reduce(operator.mul, self.shape, 1) + return self._size + + @property + def ndim(self) -> int: + return len(self.shape) + + def __len__(self) -> int: + if self.ndim > 0: + return self.shape[0] + else: # mimic numpy + raise TypeError("len() of unsized object") + + def __eq__(self, other) -> bool: + if not isinstance(other, ShapeWithDtype): + return False + else: + return (self.shape, self.dtype) == (other.shape, other.dtype) + + def __repr__(self): + nm = self.__class__.__name__ + return f"{nm}(shape={self.shape}, dtype={self.dtype})" + + +def get_dtype(v: Any): + if hasattr(v, "dtype"): + return v.dtype + else: + return type(v) + + +def common_type(*trees): + from numpy import find_common_type + + common_dtp = find_common_type( + tuple( + find_common_type(tuple(get_dtype(v) for v in tree_leaves(tr)), ()) + for tr in trees + ), () + ) + return common_dtp + + +def _size(x): + return x.size if hasattr(x, "size") else jnp.size(x) + + +def size(tree, axis: Optional[int] = None) -> int: + if axis is not None: + raise TypeError("axis of an arbitrary tree is ill defined") + sizes = tree_map(_size, tree) + return tree_reduce(operator.add, sizes) + + +def _shape(x): + return x.shape if hasattr(x, "shape") else jnp.shape(x) + + +T = TypeVar("T") + + +def shape(tree: T) -> T: + return tree_map(_shape, tree) + + +def _zeros_like(x, dtype, shape): + if hasattr(x, "shape") and hasattr(x, "dtype"): + shp = x.shape if shape is None else shape + dtp = x.dtype if dtype is None else dtype + return jnp.zeros(shape=shp, dtype=dtp) + return jnp.zeros_like(x, dtype=dtype, shape=shape) + + +def zeros_like(a, dtype=None, shape=None): + return tree_map(partial(_zeros_like, dtype=dtype, shape=shape), a) + + +def norm(tree, ord, *, ravel: bool): + from jax.numpy.linalg import norm + + if ravel: + + def el_norm(x): + return jnp.abs(x) if jnp.ndim(x) == 0 else norm(x.ravel(), ord=ord) + else: + + def el_norm(x): + return jnp.abs(x) if jnp.ndim(x) == 0 else norm(x, ord=ord) + + return norm(tree_leaves(tree_map(el_norm, tree)), ord=ord) + + +def _ravel(x): + return x.ravel() if hasattr(x, "ravel") else jnp.ravel(x) + + +def dot(a, b, *, precision=None): + tree_of_dots = tree_map( + lambda x, y: jnp.dot(_ravel(x), _ravel(y), precision=precision), a, b + ) + return tree_reduce(operator.add, tree_of_dots, 0.) + + +def vdot(a, b, *, precision=None): + tree_of_vdots = tree_map( + lambda x, y: jnp.vdot(_ravel(x), _ravel(y), precision=precision), a, b + ) + return tree_reduce(jnp.add, tree_of_vdots, 0.) + + +def select(pred, on_true, on_false): + return tree_map(partial(lax.select, pred), on_true, on_false) + + +def where(condition, x, y): + """Selects a pytree based on the condition which can be a pytree itself. + + Notes + ----- + If `condition` is not a pytree, then a partially evaluated selection is + simply mapped over `x` and `y` without actually broadcasting `condition`. 
+ """ + import numpy as np + from itertools import repeat + + ts_c = tree_structure(condition) + ts_x = tree_structure(x) + ts_y = tree_structure(y) + ts_max = (ts_c, ts_x, ts_y)[np.argmax( + [ts_c.num_nodes, ts_x.num_nodes, ts_y.num_nodes] + )] + + if ts_x.num_nodes < ts_max.num_nodes: + if ts_x.num_nodes > 1: + raise ValueError("can not broadcast LHS") + x = ts_max.unflatten(repeat(x, ts_max.num_leaves)) + if ts_y.num_nodes < ts_max.num_nodes: + if ts_y.num_nodes > 1: + raise ValueError("can not broadcast RHS") + y = ts_max.unflatten(repeat(y, ts_max.num_leaves)) + + if ts_c.num_nodes < ts_max.num_nodes: + if ts_c.num_nodes > 1: + raise ValueError("can not map condition") + return tree_map(partial(jnp.where, condition), x, y) + return tree_map(jnp.where, condition, x, y) + + +def stack(arrays): + return tree_map(lambda *el: jnp.stack(el), *arrays) + + +def unstack(stack): + element_count = tree_leaves(stack)[0].shape[0] + split = partial(jnp.split, indices_or_sections=element_count) + unstacked = tree_transpose( + tree_structure(stack), tree_structure((0., ) * element_count), + tree_map(split, stack) + ) + return tree_map(partial(jnp.squeeze, axis=0), unstacked) + + +def map_forest( + f: Callable, + in_axes: Union[int, Tuple] = 0, + out_axes: Union[int, Tuple] = 0, + tree_transpose_output: bool = True, + mapping: Union[str, Callable] = 'vmap', + **kwargs +) -> Callable: + from jax import vmap, pmap + + if out_axes != 0: + raise TypeError("`out_axis` not yet supported") + in_axes = in_axes if isinstance(in_axes, tuple) else (in_axes, ) + i = None + for idx, el in enumerate(in_axes): + if el is not None and i is None: + i = idx + elif el is not None and i is not None: + ve = "mapping over more than one axis is not yet supported" + raise ValueError(ve) + if i is None: + raise ValueError("must map over at least one axis") + if not isinstance(i, int): + te = "mapping over a non integer axis is not yet supported" + raise TypeError(te) + + if isinstance(mapping, str): + if mapping == 'vmap' or mapping == 'v': + f_map = vmap(f, in_axes=in_axes, out_axes=out_axes, **kwargs) + elif mapping == 'pmap' or mapping == 'p': + f_map = pmap(f, in_axes=in_axes, out_axes=out_axes, **kwargs) + elif mapping == 'lax.map' or mapping == 'lax': + if all(el == 0 + for el in in_axes) and np.all(0 == np.array(out_axes)): + f_map = partial(lax.map, f) + else: + ve = ( + "mapping `in_axes` and `out_axes` along another axis than" + " the 0-axis is not possible for `lax.map`" + ) + raise ValueError(ve) + else: + ve = ( + f"{mapping} is not an accepted key to a mapping function" + "; please pass function directly" + ) + raise ValueError(ve) + elif callable(mapping): + f_map = mapping(f, in_axes=in_axes, out_axes=out_axes, **kwargs) + else: + te = ( + f"invalid `mapping` of type {type(mapping)!r}" + "; expected string or callable" + ) + raise TypeError(te) + + def apply(*xs): + if not isinstance(xs[i], (list, tuple)): + te = f"expected mapped axes to be a tuple; got {type(xs[i])}" + raise TypeError(te) + x_T = stack(xs[i]) + + out_T = f_map(*xs[:i], x_T, *xs[i + 1:]) + # Since `out_axes` is forced to be `0`, we don't need to worry about + # transposing only part of the output + if not tree_transpose_output: + return out_T + return unstack(out_T) + + return apply + + +def map_forest_mean(method, mapping='vmap', *args, **kwargs) -> Callable: + method_map = map_forest( + method, *args, tree_transpose_output=False, mapping=mapping, **kwargs + ) + + def meaned_apply(*xs, **xs_kw): + return tree_map(partial(jnp.mean, axis=0), 
method_map(*xs, **xs_kw)) + + return meaned_apply diff --git a/src/re/hmc.py b/src/re/hmc.py new file mode 100644 index 0000000000000000000000000000000000000000..e3bf05ae5def3afc17fd1bab9480857f89922d8b --- /dev/null +++ b/src/re/hmc.py @@ -0,0 +1,630 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +from typing import Callable, NamedTuple, TypeVar, Union + +from jax import numpy as jnp +from jax import random, tree_util +from jax.experimental import host_callback +from jax.lax import population_count +from jax.scipy.special import expit + +from .disable_jax_control_flow import cond, fori_loop, while_loop +from .forest_util import select +from .sugar import random_like + +_DEBUG_FLAG = False + +_DEBUG_TREE_END_IDXS = [] +_DEBUG_SUBTREE_END_IDXS = [] +_DEBUG_STORE = [] + + +def _DEBUG_ADD_QP(qp): + """Stores **all** results of leapfrog integration""" + global _DEBUG_STORE + _DEBUG_STORE.append(qp) + + +def _DEBUG_FINISH_TREE(dummy_arg): + """Signal the position of a finished tree in `_DEBUG_STORE`""" + global _DEBUG_TREE_END_IDXS + _DEBUG_TREE_END_IDXS.append(len(_DEBUG_STORE)) + + +def _DEBUG_FINISH_SUBTREE(dummy_arg): + """Signal the position of a finished sub-tree in `_DEBUG_STORE`""" + global _DEBUG_SUBTREE_END_IDXS + _DEBUG_SUBTREE_END_IDXS.append(len(_DEBUG_STORE)) + + +### COMMON FUNCTIONALITY +Q = TypeVar("Q") + + +class QP(NamedTuple): + """Object holding a pair of position and momentum. + + Attributes + ---------- + position : Q + Position. + momentum : Q + Momentum. + """ + position: Q + momentum: Q + + +def flip_momentum(qp: QP) -> QP: + return QP(position=qp.position, momentum=-qp.momentum) + + +def sample_momentum_from_diagonal(*, key, mass_matrix_sqrt): + """ + Draw a momentum sample from the kinetic energy of the hamiltonian. + + Parameters + ---------- + key: ndarray + PRNGKey used as the random key. + mass_matrix_sqrt: ndarray + The left square-root mass matrix (i.e. square-root of the inverse + diagonal covariance) to use for sampling. Diagonal matrix represented + as (possibly pytree of) ndarray vector containing the entries of the + diagonal. + """ + normal = random_like(key=key, primals=mass_matrix_sqrt, rng=random.normal) + return tree_util.tree_map(jnp.multiply, mass_matrix_sqrt, normal) + + +# TODO: how to randomize step size (neal sect. 3.2) +# @partial(jit, static_argnames=('potential_energy_gradient',)) +def leapfrog_step( + potential_energy_gradient, + kinetic_energy_gradient, + step_size, + inverse_mass_matrix, + qp: QP, +): + """ + Perform one iteration of the leapfrog integrator forwards in time. + + Parameters + ---------- + potential_energy_gradient: Callable[[ndarray], float] + Potential energy gradient part of the hamiltonian (V). Depends on + position only. + qp: QP + Point in position and momentum space from which to start integration. + step_size: float + Step length (usually called epsilon) of the leapfrog integrator. + """ + position = qp.position + momentum = qp.momentum + + momentum_halfstep = ( + momentum - (step_size / 2.) * potential_energy_gradient(position) + ) + + position_fullstep = position + step_size * kinetic_energy_gradient( + inverse_mass_matrix, momentum_halfstep + ) + + momentum_fullstep = ( + momentum_halfstep - + (step_size / 2.) 
* potential_energy_gradient(position_fullstep)
+    )
+
+    qp_fullstep = QP(position=position_fullstep, momentum=momentum_fullstep)
+
+    global _DEBUG_FLAG
+    if _DEBUG_FLAG:
+        # append result to global list variable
+        host_callback.call(_DEBUG_ADD_QP, qp_fullstep)
+
+    return qp_fullstep
+
+
+### SIMPLE HMC
+class AcceptedAndRejected(NamedTuple):
+    accepted_qp: QP
+    rejected_qp: QP
+    accepted: Union[jnp.ndarray, bool]
+    diverging: Union[jnp.ndarray, bool]
+
+
+# @partial(jit, static_argnames=('potential_energy', 'potential_energy_gradient'))
+def generate_hmc_acc_rej(
+    *, key, initial_qp, potential_energy, kinetic_energy, inverse_mass_matrix,
+    stepper, num_steps, step_size, max_energy_difference
+) -> AcceptedAndRejected:
+    """
+    Generate a sample given the initial position.
+
+    Parameters
+    ----------
+    key: ndarray
+        A PRNGKey used as the random key.
+    initial_qp: QP
+        The starting pair of (position, momentum) for this step of the Markov
+        chain.
+    potential_energy: Callable[[ndarray], float]
+        The potential energy, which is the distribution to be sampled from.
+    inverse_mass_matrix: ndarray
+        The inverse mass matrix used in the kinetic energy.
+    num_steps: int
+        The number of steps the leapfrog integrator should perform.
+    step_size: float
+        The step size (usually epsilon) for the leapfrog integrator.
+    """
+    loop_body = partial(stepper, step_size, inverse_mass_matrix)
+    new_qp = fori_loop(
+        lower=0,
+        upper=num_steps,
+        body_fun=lambda _, args: loop_body(args),
+        init_val=initial_qp
+    )
+    # This flipping is needed to make the proposal distribution symmetric. It
+    # does not have any effect on the acceptance though because the kinetic
+    # energy depends on momentum^2. It might have an effect with other
+    # kinetic energies though.
+    proposed_qp = flip_momentum(new_qp)
+
+    total_energy = partial(
+        total_energy_of_qp,
+        potential_energy=potential_energy,
+        kinetic_energy_w_inv_mass=partial(kinetic_energy, inverse_mass_matrix)
+    )
+    energy_diff = total_energy(initial_qp) - total_energy(proposed_qp)
+    energy_diff = jnp.where(jnp.isnan(energy_diff), jnp.inf, energy_diff)
+    transition_probability = jnp.minimum(1., jnp.exp(energy_diff))
+
+    accept = random.bernoulli(key, transition_probability)
+    accepted_qp, rejected_qp = select(
+        accept,
+        (proposed_qp, initial_qp),
+        (initial_qp, proposed_qp),
+    )
+    diverging = jnp.abs(energy_diff) > max_energy_difference
+    return AcceptedAndRejected(
+        accepted_qp, rejected_qp, accepted=accept, diverging=diverging
+    )
+
+
+### NUTS
+class Tree(NamedTuple):
+    """Object carrying tree metadata.
+
+    Attributes
+    ----------
+    left, right : QP
+        Respective endpoints of the tree's path.
+    logweight: Union[jnp.ndarray, float]
+        Sum over all -H(q, p) in the tree's path.
+    proposal_candidate: QP
+        Sample from the tree's path, distributed as exp(-H(q, p)).
+    turning: Union[jnp.ndarray, bool]
+        Indicator for whether the left and right endpoints constitute a
+        U-turn or whether any subtree is a U-turn.
+    diverging: Union[jnp.ndarray, bool]
+        Indicator for a large increase in energy in the next larger tree.
+    depth: Union[jnp.ndarray, int]
+        Levels of the tree.
+    cumulative_acceptance: Union[jnp.ndarray, float]
+        Sum of all acceptance probabilities relative to some initial energy
+        value. This value is distinct from `logweight` as its absolute value
+        is only well defined for the very final tree of NUTS.
+ """ + left: QP + right: QP + logweight: Union[jnp.ndarray, float] + proposal_candidate: QP + turning: Union[jnp.ndarray, bool] + diverging: Union[jnp.ndarray, bool] + depth: Union[jnp.ndarray, int] + cumulative_acceptance: Union[jnp.ndarray, float] + + +def total_energy_of_qp(qp, potential_energy, kinetic_energy_w_inv_mass): + return potential_energy(qp.position + ) + kinetic_energy_w_inv_mass(qp.momentum) + + +def generate_nuts_tree( + initial_qp, + key, + step_size, + max_tree_depth, + stepper: Callable[[Union[jnp.ndarray, float], Q, QP], QP], + potential_energy, + kinetic_energy: Callable[[Q, Q], float], + inverse_mass_matrix: Q, + bias_transition: bool = True, + max_energy_difference: Union[jnp.ndarray, float] = jnp.inf +) -> Tree: + """Generate a sample given the initial position. + + This call implements a No-U-Turn-Sampler. + + Parameters + ---------- + initial_qp: QP + Starting pair of (position, momentum). **NOTE**, the momentum must be + resampled from conditional distribution **BEFORE** passing it into this + function! + key: ndarray + PRNGKey used as the random key. + step_size: float + Step size (usually called epsilon) for the leapfrog integrator. + max_tree_depth: int + The maximum depth of the trajectory tree before the expansion is + terminated. At the maximum iteration depth, the current value is + returned even if the U-turn condition is not met. The maximum number of + points (/integration steps) per trajectory is :math:`N = + 2^{\\mathrm{max\\_tree\\_depth}}`. This function requires memory linear + in max_tree_depth, i.e. logarithmic in trajectory length. It is used to + statically allocate memory in advance. + stepper: Callable[[float, Q, QP], QP] + The function that performs (Leapfrog) steps. Takes as arguments (in order) + 1) step size (containing the direction): float , + 2) inverse mass matrix: Q , + 3) starting point: QP . + potential_energy: Callable[[Q], float] + The potential energy, of the distribution to be sampled from. Takes + only the position part (QP.position) as argument. + kinetic_energy: Callable[[Q, Q], float], optional + Mapping of the momentum to its corresponding kinetic energy. As + argument the function takes the inverse mass matrix and the momentum. + + Returns + ------- + current_tree: Tree + The final tree, carrying a sample from the target distribution. 
+ + See Also + -------- + No-U-Turn Sampler original paper (2011): https://arxiv.org/abs/1111.4246 + NumPyro Iterative NUTS paper: https://arxiv.org/abs/1912.11554 + Combination of samples from two trees, Sampling from trajectories according to target distribution in this paper's Appendix: https://arxiv.org/abs/1701.02434 + """ + # initialize depth 0 tree, containing 2**0 = 1 points + initial_neg_energy = -total_energy_of_qp( + initial_qp, potential_energy, + partial(kinetic_energy, inverse_mass_matrix) + ) + current_tree = Tree( + left=initial_qp, + right=initial_qp, + logweight=initial_neg_energy, + proposal_candidate=initial_qp, + turning=False, + diverging=False, + depth=0, + cumulative_acceptance=jnp.zeros_like(initial_neg_energy) + ) + + def _cont_cond(loop_state): + _, current_tree, stop = loop_state + return (~stop) & (current_tree.depth <= max_tree_depth) + + def cond_tree_doubling(loop_state): + key, current_tree, _ = loop_state + key, key_dir, key_subtree, key_merge = random.split(key, 4) + + go_right = random.bernoulli(key_dir, 0.5) + + # build tree adjacent to current_tree + new_subtree = iterative_build_tree( + key_subtree, + current_tree, + step_size, + go_right, + stepper, + potential_energy, + kinetic_energy, + inverse_mass_matrix, + max_tree_depth, + initial_neg_energy=initial_neg_energy, + max_energy_difference=max_energy_difference + ) + # Mark current tree as diverging if it diverges in the next step + current_tree = current_tree._replace(diverging=new_subtree.diverging) + + # combine current_tree and new_subtree into a tree which is one layer deeper only if new_subtree has no turning subtrees (including itself) + current_tree = cond( + # If new tree is turning or diverging, do not merge + pred=new_subtree.turning | new_subtree.diverging, + true_fun=lambda old_and_new: old_and_new[0], + false_fun=lambda old_and_new: merge_trees( + key_merge, + old_and_new[0], + old_and_new[1], + go_right, + bias_transition=bias_transition + ), + operand=(current_tree, new_subtree), + ) + # stop if new subtree was turning -> we sample from the old one and don't expand further + # stop if new total tree is turning -> we sample from the combined trajectory and don't expand further + stop = new_subtree.turning | current_tree.turning + stop |= new_subtree.diverging + return (key, current_tree, stop) + + loop_state = (key, current_tree, False) + _, current_tree, _ = while_loop(_cont_cond, cond_tree_doubling, loop_state) + + global _DEBUG_FLAG + if _DEBUG_FLAG: + host_callback.call(_DEBUG_FINISH_TREE, None) + + return current_tree + + +def tree_index_get(ptree, idx): + return tree_util.tree_map(lambda arr: arr[idx], ptree) + + +def tree_index_update(x, idx, y): + from jax.tree_util import tree_map + + return tree_map(lambda x_el, y_el: x_el.at[idx].set(y_el), x, y) + + +def count_trailing_ones(n): + """Count the number of trailing, consecutive ones in the binary + representation of `n`. 
+
+ Warning
+ -------
+ `n` must be positive and strictly smaller than 2**64
+
+ Examples
+ --------
+ >>> print(bin(23), count_trailing_ones(23))
+ 0b10111 3
+ """
+ # taken from http://num.pyro.ai/en/stable/_modules/numpyro/infer/hmc_util.html
+ _, trailing_ones_count = while_loop(
+ lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)
+ )
+ return trailing_ones_count
+
+
+def is_euclidean_uturn(qp_left, qp_right):
+ """
+ See Also
+ --------
+ Betancourt - A conceptual introduction to Hamiltonian Monte Carlo
+ """
+ return (
+ (qp_right.momentum.dot(qp_right.position - qp_left.position) < 0.) &
+ (qp_left.momentum.dot(qp_left.position - qp_right.position) < 0.)
+ )
+
+
+# Essentially algorithm 2 from https://arxiv.org/pdf/1912.11554.pdf
+def iterative_build_tree(
+ key, initial_tree, step_size, go_right, stepper, potential_energy,
+ kinetic_energy, inverse_mass_matrix, max_tree_depth, initial_neg_energy,
+ max_energy_difference
+):
+ """
+ Starting from either the left or right endpoint of a given tree, builds a
+ new adjacent tree of the same size.
+
+ Parameters
+ ----------
+ key: ndarray
+ PRNGKey to choose a sample when adding QPs to the tree.
+ initial_tree: Tree
+ Tree to be extended (doubled) on the left or right.
+ step_size: float
+ The step size (usually called epsilon) for the leapfrog integrator.
+ go_right: bool
+ If `go_right`, start at the right end and go right; otherwise start at
+ the left end and go left.
+ stepper: Callable[[float, Q, QP], QP]
+ The function that performs (Leapfrog) steps. Takes as arguments (in order)
+ 1) step size (containing the direction): float,
+ 2) inverse mass matrix: Q,
+ 3) starting point: QP.
+ potential_energy: Callable[[Q], float]
+ Potential energy of the distribution to be sampled from. Takes
+ only the position part (QP.position) as argument.
+ kinetic_energy: Callable[[Q, Q], float], optional
+ Mapping of the momentum to its corresponding kinetic energy. As
+ arguments the function takes the inverse mass matrix and the momentum.
+ max_tree_depth: int
+ An upper bound on the 'depth' argument, which has no effect on the
+ function's behaviour. It is only required to statically set the size of
+ the storage array `S` (of type Q).
+ """
+ # 1. choose start point of integration
+ z = select(go_right, initial_tree.right, initial_tree.left)
+ depth = initial_tree.depth
+ max_num_proposals = 2**depth
+ # 2. build / collect new states
+ # Create a storage for left endpoints of subtrees. Size is determined
+ # statically by the `max_tree_depth` parameter.
+ # NOTE, in principle only `max_tree_depth` elements are needed even though
+ # the tree can be of length `max_tree_depth + 1`; the last element is
+ # never accessed.
+ S = tree_util.tree_map(
+ lambda proto: jnp.
+ empty_like(proto, shape=(max_tree_depth, ) + jnp.shape(proto)), z
+ )
+
+ z = stepper(
+ jnp.where(go_right, 1., -1.)
* step_size, inverse_mass_matrix, z
+ )
+ neg_energy = -total_energy_of_qp(
+ z, potential_energy, partial(kinetic_energy, inverse_mass_matrix)
+ )
+ diverging = jnp.abs(neg_energy - initial_neg_energy) > max_energy_difference
+ cum_acceptance = jnp.minimum(1., jnp.exp(initial_neg_energy - neg_energy))
+ incomplete_tree = Tree(
+ left=z,
+ right=z,
+ logweight=neg_energy,
+ proposal_candidate=z,
+ turning=False,
+ diverging=diverging,
+ depth=-1,
+ cumulative_acceptance=cum_acceptance
+ )
+ S = tree_index_update(S, 0, z)
+
+ def amend_incomplete_tree(state):
+ n, incomplete_tree, z, S, key = state
+
+ key, key_choose_candidate = random.split(key)
+ z = stepper(
+ jnp.where(go_right, 1., -1.) * step_size, inverse_mass_matrix, z
+ )
+ incomplete_tree = add_single_qp_to_tree(
+ key_choose_candidate,
+ incomplete_tree,
+ z,
+ go_right,
+ potential_energy,
+ kinetic_energy,
+ inverse_mass_matrix,
+ initial_neg_energy=initial_neg_energy,
+ max_energy_difference=max_energy_difference
+ )
+
+ def _even_fun(S):
+ # n is even, the current z is w.l.o.g. a left endpoint of some
+ # subtrees. Register the current z to be used in turning condition
+ # checks later, when the right endpoints of its subtrees are
+ # generated.
+ S = tree_index_update(S, population_count(n), z)
+ return S, False
+
+ def _odd_fun(S):
+ # n is odd, the current z is w.l.o.g. a right endpoint of some
+ # subtrees. Check turning condition against all left endpoints of
+ # subtrees that have the current z (/n) as their right endpoint.
+
+ # l = number of subtrees that have current z as their right endpoint.
+ l = count_trailing_ones(n)
+ # inclusive indices into S referring to the left endpoints of the l subtrees.
+ i_max_incl = population_count(n - 1)
+ i_min_incl = i_max_incl - l + 1
+ # TODO: this should traverse the range in reverse
+ turning = fori_loop(
+ lower=i_min_incl,
+ upper=i_max_incl + 1,
+ # TODO: conditional for early termination
+ body_fun=lambda k, turning: turning |
+ is_euclidean_uturn(tree_index_get(S, k), z),
+ init_val=False
+ )
+ return S, turning
+
+ S, turning = cond(
+ pred=n % 2 == 0, true_fun=_even_fun, false_fun=_odd_fun, operand=S
+ )
+ incomplete_tree = incomplete_tree._replace(turning=turning)
+ return (n + 1, incomplete_tree, z, S, key)
+
+ def _cont_cond(state):
+ n, incomplete_tree, *_ = state
+ return (n < max_num_proposals) & (~incomplete_tree.turning
+ ) & (~incomplete_tree.diverging)
+
+ n, incomplete_tree, *_ = while_loop(
+ # while n < 2**depth and not stop
+ cond_fun=_cont_cond,
+ body_fun=amend_incomplete_tree,
+ init_val=(1, incomplete_tree, z, S, key)
+ )
+
+ global _DEBUG_FLAG
+ if _DEBUG_FLAG:
+ host_callback.call(_DEBUG_FINISH_SUBTREE, None)
+
+ # The depth of a tree which was aborted early is possibly ill defined
+ depth = jnp.where(n == max_num_proposals, depth, -1)
+ return incomplete_tree._replace(depth=depth)
+
+
+def add_single_qp_to_tree(
+ key, tree, qp, go_right, potential_energy, kinetic_energy,
+ inverse_mass_matrix, initial_neg_energy, max_energy_difference
+):
+ """Helper function for progressive sampling. Takes a tree with a sample
+ and a new endpoint, and propagates the sample.
+ """
+ # This is technically just a special case of merge_trees with one of the
+ # trees being a singleton, depth 0 tree. However, no turning check is
+ # required and it is not possible to bias the transition.
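+ # Sketch of the progressive-sampling rule used below: with per-point
+ # weights w_i = exp(-H_i), keeping the old candidate with probability
+ # w_old / (w_old + w_new) = expit(logw_old - logw_new) leaves every point
+ # visited so far selected with probability proportional to its weight.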
+ left, right = select(go_right, (tree.left, qp), (qp, tree.right))
+
+ neg_energy = -total_energy_of_qp(
+ qp, potential_energy, partial(kinetic_energy, inverse_mass_matrix)
+ )
+ diverging = jnp.abs(neg_energy - initial_neg_energy) > max_energy_difference
+ # ln(e^-H_1 + e^-H_2)
+ total_logweight = jnp.logaddexp(tree.logweight, neg_energy)
+ # expit(x-y) := 1 / (1 + e^(-(x-y))) = 1 / (1 + e^(y-x)) = e^x / (e^y + e^x)
+ prob_of_keeping_old = expit(tree.logweight - neg_energy)
+ remain = random.bernoulli(key, prob_of_keeping_old)
+ proposal_candidate = select(remain, tree.proposal_candidate, qp)
+ # NOTE, set an invalid depth to indicate that adding a single QP to a
+ # perfect binary tree does not yield another perfect binary tree
+ cum_acceptance = tree.cumulative_acceptance + jnp.minimum(
+ 1., jnp.exp(initial_neg_energy - neg_energy)
+ )
+ return Tree(
+ left,
+ right,
+ total_logweight,
+ proposal_candidate,
+ turning=tree.turning,
+ diverging=diverging,
+ depth=-1,
+ cumulative_acceptance=cum_acceptance
+ )
+
+
+def merge_trees(key, current_subtree, new_subtree, go_right, bias_transition):
+ """Merges two trees, propagating the proposal_candidate."""
+ # 5. decide which sample to take based on total weights (merge trees)
+ if bias_transition:
+ # Bias the transition towards the new subtree (see Betancourt
+ # conceptual intro (and Numpyro))
+ transition_probability = jnp.minimum(
+ 1., jnp.exp(new_subtree.logweight - current_subtree.logweight)
+ )
+ else:
+ # expit(x-y) := 1 / (1 + e^(-(x-y))) = 1 / (1 + e^(y-x)) = e^x / (e^y + e^x)
+ transition_probability = expit(
+ new_subtree.logweight - current_subtree.logweight
+ )
+ # print(f"prob of choosing new sample: {transition_probability}")
+ new_sample = select(
+ random.bernoulli(key, transition_probability),
+ new_subtree.proposal_candidate, current_subtree.proposal_candidate
+ )
+ # 6.
define new tree
+ left, right = select(
+ go_right,
+ (current_subtree.left, new_subtree.right),
+ (new_subtree.left, current_subtree.right),
+ )
+ turning = is_euclidean_uturn(left, right)
+ diverging = current_subtree.diverging | new_subtree.diverging
+ neg_energy = jnp.logaddexp(new_subtree.logweight, current_subtree.logweight)
+ cum_acceptance = current_subtree.cumulative_acceptance + new_subtree.cumulative_acceptance
+ merged_tree = Tree(
+ left=left,
+ right=right,
+ logweight=neg_energy,
+ proposal_candidate=new_sample,
+ turning=turning,
+ diverging=diverging,
+ depth=current_subtree.depth + 1,
+ cumulative_acceptance=cum_acceptance
+ )
+ return merged_tree
diff --git a/src/re/hmc_oo.py b/src/re/hmc_oo.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c001065cc35ae03dbf545ddad1428c4d3cb3b28
--- /dev/null
+++ b/src/re/hmc_oo.py
@@ -0,0 +1,355 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
+
+import numpy as np
+from jax import grad
+from jax import numpy as jnp
+from jax import random, tree_util
+
+from .disable_jax_control_flow import fori_loop
+from .hmc import AcceptedAndRejected, Q, QP, Tree
+from .hmc import (
+ generate_hmc_acc_rej,
+ generate_nuts_tree,
+ leapfrog_step,
+ sample_momentum_from_diagonal,
+ tree_index_update,
+)
+
+
+def _parse_diag_mass_matrix(mass_matrix, position_proto: Q) -> Q:
+ if isinstance(mass_matrix,
+ (float, jnp.ndarray)) and jnp.size(mass_matrix) == 1:
+ mass_matrix = tree_util.tree_map(
+ partial(jnp.full_like, fill_value=mass_matrix), position_proto
+ )
+ elif tree_util.tree_structure(mass_matrix
+ ) == tree_util.tree_structure(position_proto):
+ shape_match_tree = tree_util.tree_map(
+ lambda a1, a2: jnp.shape(a1) == jnp.shape(a2), mass_matrix,
+ position_proto
+ )
+ shape_and_structure_match = all(
+ tree_util.tree_leaves(shape_match_tree)
+ )
+ if not shape_and_structure_match:
+ ve = "matrix has same tree_structure as the position but shapes do not match up"
+ raise ValueError(ve)
+ else:
+ te = "matrix must either be float or have same tree structure as the position"
+ raise TypeError(te)
+
+ return mass_matrix
+
+
+class Chain(NamedTuple):
+ """Object carrying chain metadata; think: transposed Tree with new axis.
+ """
+ # Q but with one more dimension on the first axes of the leaf tensors
+ samples: Q
+ divergences: jnp.ndarray
+ acceptance: Union[jnp.ndarray, float]
+ depths: Optional[jnp.ndarray] = None
+ trees: Optional[Union[Tree, AcceptedAndRejected]] = None
+
+
+class _Sampler:
+ def __init__(
+ self,
+ potential_energy: Callable[[Q], Union[jnp.ndarray, float]],
+ inverse_mass_matrix,
+ position_proto: Q,
+ step_size: Union[jnp.ndarray, float] = 1.0,
+ max_energy_difference: Union[jnp.ndarray, float] = jnp.inf
+ ):
+ if not callable(potential_energy):
+ raise TypeError()
+ if not isinstance(step_size, (jnp.ndarray, float)):
+ raise TypeError()
+
+ self.potential_energy = potential_energy
+
+ self.inverse_mass_matrix = _parse_diag_mass_matrix(
+ inverse_mass_matrix, position_proto=position_proto
+ )
+ self.mass_matrix_sqrt = self.inverse_mass_matrix**(-0.5)
+
+ self.step_size = step_size
+
+ def kinetic_energy(inverse_mass_matrix, momentum):
+ # NOTE, assume a diagonal mass-matrix
+ return inverse_mass_matrix.dot(momentum**2) / 2.
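+ # For a diagonal mass matrix M this is K(p) = p^T M^{-1} p / 2, so the
+ # momentum gradient is M^{-1} p -- the elementwise product implemented
+ # by `kinetic_energy_gradient` below.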
+ + self.kinetic_energy = kinetic_energy + kinetic_energy_gradient = lambda inv_m, mom: inv_m * mom + potential_energy_gradient = grad(self.potential_energy) + self.stepper = partial( + leapfrog_step, potential_energy_gradient, kinetic_energy_gradient + ) + + self.max_energy_difference = max_energy_difference + + def sample_next_state(key, + prev_position: Q) -> Tuple[Any, Tuple[Any, Q]]: + raise NotImplementedError() + + self.sample_next_state = sample_next_state + + @staticmethod + def init_chain( + num_samples: int, position_proto, save_intermediates: bool + ) -> Chain: + raise NotImplementedError() + + @staticmethod + def update_chain( + chain: Chain, idx: Union[jnp.ndarray, int], tree: Tree + ) -> Chain: + raise NotImplementedError() + + def generate_n_samples( + self, + key: Any, + initial_position: Q, + num_samples, + *, + save_intermediates: bool = False + ) -> Tuple[Chain, Tuple[Any, Q]]: + if not isinstance(key, (jnp.ndarray, np.ndarray)): + if isinstance(key, int): + key = random.PRNGKey(key) + else: + raise TypeError() + + chain = self.init_chain( + num_samples, initial_position, save_intermediates + ) + + def amend_chain(idx, state): + chain, core_state = state + tree, core_state = self.sample_next_state(*core_state) + chain = self.update_chain(chain, idx, tree) + return chain, core_state + + chain, core_state = fori_loop( + lower=0, + upper=num_samples, + body_fun=amend_chain, + init_val=(chain, (key, initial_position)) + ) + + return chain, core_state + + +class NUTSChain(_Sampler): + def __init__( + self, + potential_energy: Callable[[Q], Union[float, jnp.ndarray]], + inverse_mass_matrix, + position_proto: Q, + step_size: float = 1.0, + max_tree_depth: int = 10, + bias_transition: bool = True, + max_energy_difference: float = jnp.inf + ): + super().__init__( + potential_energy=potential_energy, + inverse_mass_matrix=inverse_mass_matrix, + position_proto=position_proto, + step_size=step_size, + max_energy_difference=max_energy_difference + ) + + if not isinstance(max_tree_depth, int): + raise TypeError() + self.bias_transition = bias_transition + self.max_tree_depth = max_tree_depth + + def sample_next_state(key, + prev_position: Q) -> Tuple[Tree, Tuple[Any, Q]]: + key, key_momentum, key_nuts = random.split(key, 3) + + resampled_momentum = sample_momentum_from_diagonal( + key=key_momentum, mass_matrix_sqrt=self.mass_matrix_sqrt + ) + qp = QP(position=prev_position, momentum=resampled_momentum) + + tree = generate_nuts_tree( + initial_qp=qp, + key=key_nuts, + step_size=self.step_size, + max_tree_depth=self.max_tree_depth, + stepper=self.stepper, + potential_energy=self.potential_energy, + kinetic_energy=self.kinetic_energy, + inverse_mass_matrix=self.inverse_mass_matrix, + bias_transition=self.bias_transition, + max_energy_difference=self.max_energy_difference + ) + return tree, (key, tree.proposal_candidate.position) + + self.sample_next_state = sample_next_state + + @staticmethod + def init_chain( + num_samples: int, position_proto, save_intermediates: bool + ) -> Chain: + samples = tree_util.tree_map( + lambda arr: jnp. 
+ zeros_like(arr, shape=(num_samples, ) + jnp.shape(arr)), + position_proto + ) + depths = jnp.zeros(num_samples, dtype=jnp.uint64) + divergences = jnp.zeros(num_samples, dtype=bool) + chain = Chain( + samples=samples, + divergences=divergences, + acceptance=0., + depths=depths + ) + if save_intermediates: + _qp_proto = QP(position_proto, position_proto) + _tree_proto = Tree( + _qp_proto, + _qp_proto, + 0., + _qp_proto, + turning=True, + diverging=True, + depth=0, + cumulative_acceptance=0. + ) + trees = tree_util.tree_map( + lambda leaf: jnp. + zeros_like(leaf, shape=(num_samples, ) + jnp.shape(leaf)), + _tree_proto + ) + chain = chain._replace(trees=trees) + + return chain + + @staticmethod + def update_chain( + chain: Chain, idx: Union[jnp.ndarray, int], tree: Tree + ) -> Chain: + num_proposals = 2**jnp.array(tree.depth, dtype=jnp.uint64) - 1 + tree_acceptance = jnp.where( + num_proposals > 0, tree.cumulative_acceptance / num_proposals, 0. + ) + + samples = tree_index_update( + chain.samples, idx, tree.proposal_candidate.position + ) + divergences = chain.divergences.at[idx].set(tree.diverging) + depths = chain.depths.at[idx].set(tree.depth) + acceptance = ( + chain.acceptance + (tree_acceptance - chain.acceptance) / (idx + 1) + ) + chain = chain._replace( + samples=samples, + divergences=divergences, + acceptance=acceptance, + depths=depths + ) + if chain.trees is not None: + trees = tree_index_update(chain.trees, idx, tree) + chain = chain._replace(trees=trees) + + return chain + + +class HMCChain(_Sampler): + def __init__( + self, + potential_energy: Callable, + inverse_mass_matrix, + position_proto, + num_steps, + step_size: float = 1.0, + max_energy_difference: float = jnp.inf + ): + super().__init__( + potential_energy=potential_energy, + inverse_mass_matrix=inverse_mass_matrix, + position_proto=position_proto, + step_size=step_size, + max_energy_difference=max_energy_difference + ) + + if not isinstance(num_steps, (jnp.ndarray, int)): + raise TypeError() + self.num_steps = num_steps + + def sample_next_state(key, + prev_position: Q) -> Tuple[Tree, Tuple[Any, Q]]: + key, key_choose, key_momentum_resample = random.split(key, 3) + + resampled_momentum = sample_momentum_from_diagonal( + key=key_momentum_resample, + mass_matrix_sqrt=self.mass_matrix_sqrt + ) + qp = QP(position=prev_position, momentum=resampled_momentum) + + acc_rej = generate_hmc_acc_rej( + key=key_choose, + initial_qp=qp, + potential_energy=self.potential_energy, + kinetic_energy=self.kinetic_energy, + inverse_mass_matrix=self.inverse_mass_matrix, + stepper=self.stepper, + num_steps=self.num_steps, + step_size=self.step_size, + max_energy_difference=self.max_energy_difference + ) + return acc_rej, (key, acc_rej.accepted_qp.position) + + self.sample_next_state = sample_next_state + + @staticmethod + def init_chain( + num_samples: int, position_proto, save_intermediates: bool + ) -> Chain: + samples = tree_util.tree_map( + lambda arr: jnp. + zeros_like(arr, shape=(num_samples, ) + jnp.shape(arr)), + position_proto + ) + divergences = jnp.zeros(num_samples, dtype=bool) + chain = Chain(samples=samples, divergences=divergences, acceptance=0.) + if save_intermediates: + _qp_proto = QP(position_proto, position_proto) + _acc_rej_proto = AcceptedAndRejected( + _qp_proto, _qp_proto, True, True + ) + trees = tree_util.tree_map( + lambda leaf: jnp. 
+ zeros_like(leaf, shape=(num_samples, ) + jnp.shape(leaf)), + _acc_rej_proto + ) + chain = chain._replace(trees=trees) + + return chain + + @staticmethod + def update_chain( + chain: Chain, idx: Union[jnp.ndarray, int], acc_rej: AcceptedAndRejected + ) -> Chain: + samples = tree_index_update( + chain.samples, idx, acc_rej.accepted_qp.position + ) + divergences = chain.divergences.at[idx].set(acc_rej.diverging) + acceptance = ( + chain.acceptance + (acc_rej.accepted - chain.acceptance) / + (idx + 1) + ) + chain = chain._replace( + samples=samples, divergences=divergences, acceptance=acceptance + ) + if chain.trees is not None: + trees = tree_index_update(chain.trees, idx, acc_rej) + chain = chain._replace(trees=trees) + + return chain diff --git a/src/re/kl.py b/src/re/kl.py new file mode 100644 index 0000000000000000000000000000000000000000..71a2eaf9c4082f6521fdfb6377f4a88b7411bdd4 --- /dev/null +++ b/src/re/kl.py @@ -0,0 +1,646 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union + +import jax +from jax import lax +from jax import random +from jax.tree_util import Partial, register_pytree_node_class + +from . import conjugate_gradient +from .forest_util import assert_arithmetics, map_forest, map_forest_mean, unstack +from .likelihood import Likelihood, StandardHamiltonian +from .sugar import random_like + +P = TypeVar("P") + + +def sample_likelihood(likelihood: Likelihood, primals, key): + white_sample = random_like(key, likelihood.left_sqrt_metric_tangents_shape) + return likelihood.left_sqrt_metric(primals, white_sample) + + +def cond_raise(condition, exception): + from jax.experimental.host_callback import call + + def maybe_raise(condition): + if condition: + raise exception + + call(maybe_raise, condition, result_shape=None) + + +def _sample_standard_hamiltonian( + hamiltonian: StandardHamiltonian, + primals, + key, + from_inverse: bool, + cg: Callable = conjugate_gradient.static_cg, + cg_name: Optional[str] = None, + cg_kwargs: Optional[dict] = None, +): + if not isinstance(hamiltonian, StandardHamiltonian): + te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'" + raise TypeError(te) + cg_kwargs = cg_kwargs if cg_kwargs is not None else {} + + subkey_nll, subkey_prr = random.split(key, 2) + nll_smpl = sample_likelihood( + hamiltonian.likelihood, primals, key=subkey_nll + ) + prr_inv_metric_smpl = random_like(key=subkey_prr, primals=primals) + # One may transform any metric sample to a sample of the inverse + # metric by simply applying the inverse metric to it + prr_smpl = prr_inv_metric_smpl + # Note, we can sample antithetically by swapping the global sign of + # the metric sample below (which corresponds to mirroring the final + # sample) and additionally by swapping the relative sign between + # the prior and the likelihood sample. The first technique is + # computationally cheap and empirically known to improve stability. + # The latter technique requires an additional inversion and its + # impact on stability is still unknown. + # TODO: investigate the impact of sampling the prior and likelihood + # antithetically. 
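+ # Note on the sum below: `nll_smpl` is drawn via the left square root of
+ # the likelihood metric and thus has covariance M_lh, while `prr_smpl` is
+ # standard normal with covariance 1. Their independent sum hence has
+ # covariance M_lh + 1, the metric of the standardized Hamiltonian.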
+ met_smpl = nll_smpl + prr_smpl + if from_inverse: + inv_metric_at_p = partial( + cg, Partial(hamiltonian.metric, primals), **{ + "name": cg_name, + **cg_kwargs + } + ) + signal_smpl, info = inv_metric_at_p(met_smpl, x0=prr_inv_metric_smpl) + cond_raise( + (info is not None) & (info < 0), + ValueError("conjugate gradient failed") + ) + return signal_smpl, met_smpl + else: + return None, met_smpl + + +def sample_standard_hamiltonian( + hamiltonian: StandardHamiltonian, primals, *args, **kwargs +): + r"""Draws a sample of which the covariance is the metric or the inverse + metric of the Hamiltonian. + + To sample from the inverse metric, we need to be able to draw samples + which have the metric as covariance structure and we need to be able to + apply the inverse metric. The first part is trivial since we can use + the left square root of the metric :math:`L` associated with every + likelihood: + + .. math:: + + \tilde{d} \leftarrow \mathcal{G}(0,\mathbb{1}) \\ + t = L \tilde{d} + + with :math:`t` now having a covariance structure of + + .. math:: + <t t^\dagger> = L <\tilde{d} \tilde{d}^\dagger> L^\dagger = M . + + We now need to apply the inverse metric in order to transform the + sample to an inverse sample. We can do so using the conjugate gradient + algorithm which yields the solution to :math:`M s = t`, i.e. applies the + inverse of :math:`M` to :math:`t`: + + .. math:: + + M &s = t \\ + &s = M^{-1} t = cg(M, t) . + + Parameters + ---------- + hamiltonian: + Hamiltonian with standard prior from which to draw samples. + primals : tree-like structure + Position at which to draw samples. + key : tuple, list or jnp.ndarray of uint32 of length two + Random key with which to generate random variables in data domain. + cg : callable, optional + Implementation of the conjugate gradient algorithm and used to + apply the inverse of the metric. + cg_kwargs : dict, optional + Additional keyword arguments passed on to `cg`. + + Returns + ------- + sample : tree-like structure + Sample of which the covariance is the inverse metric. + + See also + -------- + `Metric Gaussian Variational Inference`, Jakob Knollmüller, + Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_ + """ + assert_arithmetics(primals) + inv_met_smpl, _ = _sample_standard_hamiltonian( + hamiltonian, primals, *args, from_inverse=True, **kwargs + ) + return inv_met_smpl + + +def geometrically_sample_standard_hamiltonian( + hamiltonian: StandardHamiltonian, + primals, + key, + mirror_linear_sample: bool, + linear_sampling_cg: Callable = conjugate_gradient.static_cg, + linear_sampling_name: Optional[str] = None, + linear_sampling_kwargs: Optional[dict] = None, + non_linear_sampling_method: str = "NewtonCG", + non_linear_sampling_name: Optional[str] = None, + non_linear_sampling_kwargs: Optional[dict] = None, +): + r"""Draws a sample which follows a standard normal distribution in the + canonical coordinate system of the Riemannian manifold associated with the + metric of the other distribution. The coordinate transformation is + approximated by expanding around a given point `primals`. + + Parameters + ---------- + hamiltonian: + Hamiltonian with standard prior from which to draw samples. + primals : tree-like structure + Position at which to draw samples. + key : tuple, list or jnp.ndarray of uint32 of length two + Random key with which to generate random variables in data domain. + linear_sampling_cg : callable + Implementation of the conjugate gradient algorithm and used to + apply the inverse of the metric. 
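+ linear_sampling_name : string, optional
+ 'name'-keyword-argument passed on to `linear_sampling_cg`.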
+ linear_sampling_kwargs : dict
+ Additional keyword arguments passed on to `linear_sampling_cg`.
+ non_linear_sampling_kwargs : dict
+ Additional keyword arguments passed on to the minimizer of the
+ non-linear potential.
+
+ Returns
+ -------
+ sample : tree-like structure
+ Sample of which the covariance is the inverse metric.
+
+ See also
+ --------
+ `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
+ Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
+ `<https://doi.org/10.3390/e23070853>`_
+ """
+ if not isinstance(hamiltonian, StandardHamiltonian):
+ te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'"
+ raise TypeError(te)
+ assert_arithmetics(primals)
+ from .energy_operators import Gaussian
+ from .optimize import minimize
+
+ inv_met_smpl, met_smpl = _sample_standard_hamiltonian(
+ hamiltonian,
+ primals,
+ key=key,
+ from_inverse=True,
+ cg=linear_sampling_cg,
+ cg_name=linear_sampling_name,
+ cg_kwargs=linear_sampling_kwargs
+ )
+
+ if isinstance(non_linear_sampling_kwargs, dict):
+ nls_kwargs = non_linear_sampling_kwargs
+ elif non_linear_sampling_kwargs is None:
+ nls_kwargs = {}
+ else:
+ te = (
+ "`non_linear_sampling_kwargs` of invalid type "
+ f"{type(non_linear_sampling_kwargs)}"
+ )
+ raise TypeError(te)
+ nls_kwargs = {"name": non_linear_sampling_name, **nls_kwargs}
+ if "hessp" in nls_kwargs:
+ ve = "setting the hessian for an unknown function is invalid"
+ raise ValueError(ve)
+ # Abort early if non-linear sampling is effectively disabled
+ if nls_kwargs.get("maxiter") == 0:
+ if mirror_linear_sample:
+ return (inv_met_smpl, -inv_met_smpl)
+ return (inv_met_smpl, )
+
+ lh_trafo_at_p = hamiltonian.likelihood.transformation(primals)
+
+ def draw_non_linear_sample(lh, met_smpl, inv_met_smpl):
+ x0 = primals + inv_met_smpl
+
+ def g(x):
+ return x - primals + lh.left_sqrt_metric(
+ primals,
+ lh.transformation(x) - lh_trafo_at_p
+ )
+
+ r2_half = Gaussian(met_smpl) @ g # (g - met_smpl)**2 / 2
+
+ options = nls_kwargs.copy()
+ options["hessp"] = r2_half.metric
+
+ opt_state = minimize(
+ r2_half, x0=x0, method=non_linear_sampling_method, options=options
+ )
+
+ return opt_state.x, opt_state.status
+
+ smpl1, smpl1_status = draw_non_linear_sample(
+ hamiltonian.likelihood, met_smpl, inv_met_smpl
+ )
+ cond_raise(
+ (smpl1_status is not None) & (smpl1_status < 0),
+ ValueError("S: failed to invert map")
+ )
+ if not mirror_linear_sample:
+ return (smpl1 - primals, )
+ smpl2, smpl2_status = draw_non_linear_sample(
+ hamiltonian.likelihood, -met_smpl, -inv_met_smpl
+ )
+ cond_raise(
+ (smpl2_status is not None) & (smpl2_status < 0),
+ ValueError("S: failed to invert map")
+ )
+ return (smpl1 - primals, smpl2 - primals)
+
+
+@register_pytree_node_class
+class SampleIter():
+ """Storage class for samples with some convenience methods for applying
+ operators on them.
+
+ This class is used to store samples for the Variational Inference schemes
+ MGVI and geoVI where samples are defined relative to some expansion point
+ (a.k.a. latent mean or offset).
+
+ See also
+ --------
+ `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
+ Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
+ `<https://doi.org/10.3390/e23070853>`_
+
+ `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+ Torsten A.
Enßlin, `<https://arxiv.org/abs/1901.11033>`_ + """ + def __init__( + self, + *, + mean: P = None, + samples: Sequence[P], + linearly_mirror_samples: bool = False, + ): + self._samples = tuple(samples) + self._mean = mean + + self._n_samples = len(self._samples) + if linearly_mirror_samples == True: + self._n_samples *= 2 + self._linearly_mirror_samples = linearly_mirror_samples + # TODO/IDEA: Implement a transposed SampleIter object (SampleStack) + # akin to `vmap_forest_mean` + + def __iter__(self): + for s in self._samples: + yield self._mean + s if self._mean is not None else s + if self._linearly_mirror_samples: + yield self._mean - s if self._mean is not None else -s + + def __len__(self): + return self._n_samples + + @property + def n_samples(self): + """Total number of samples, equivalent to the length of the object""" + return len(self) + + def at(self, mean): + """Updates the offset (usually the latent mean) of all samples""" + return SampleIter( + mean=mean, + samples=self._samples, + linearly_mirror_samples=self._linearly_mirror_samples + ) + + @property + def first(self): + """Convenience method to easily retrieve a sample (the first one)""" + if self._mean is not None: + return self._mean + self._samples[0] + return self._samples[0] + + def apply(self, call: Callable, *args, **kwargs): + """Applies an operator over all samples, yielding a list of outputs + + Internally, the call is `vmap`-ed over the samples for additional + efficiency. + """ + if set(kwargs.keys()) | {"in_axes"} != {"in_axes"}: + raise ValueError(f"invalid keyword arguments {kwargs}") + + # TODO: vmap is significantly slower than looping over the samples + # for an extremely high dimensional problem. + in_axes = kwargs.get("in_axes", (0, )) + return map_forest(call, in_axes=in_axes)(tuple(self), *args) + + def mean(self, call: Callable, *args, **kwargs): + """Applies an operator over all samples and averages the results + + Internally, the call is `vmap`-ed over the samples for additional + efficiency. + """ + if set(kwargs.keys()) | {"in_axes"} != {"in_axes"}: + raise ValueError(f"invalid keyword arguments {kwargs}") + + # TODO: vmap is significantly slower than looping over the samples + # for an extremely high dimensional problem. + in_axes = kwargs.get("in_axes", (0, )) + return map_forest_mean(call, in_axes=in_axes)(tuple(self), *args) + + def tree_flatten(self): + return ((self._mean, self._samples), (self._linearly_mirror_samples, )) + + @classmethod + def tree_unflatten(cls, aux, children): + if len(aux) != 1 or len(children) != 2: + raise ValueError() + return cls( + mean=children[0], + samples=children[1], + linearly_mirror_samples=aux[0] + ) + + +def MetricKL( + hamiltonian: StandardHamiltonian, + primals, + n_samples: int, + key, + mirror_samples: bool = True, + sample_mapping: Union[str, Callable] = 'lax', + linear_sampling_cg: Callable = conjugate_gradient.static_cg, + linear_sampling_name: Optional[str] = None, + linear_sampling_kwargs: Optional[dict] = None, +) -> SampleIter: + """Provides the sampled Kullback-Leibler divergence between a distribution + and a Metric Gaussian. + + A Metric Gaussian is used to approximate another probability distribution. + It is a Gaussian distribution that uses the Fisher information metric of + the other distribution at the location of its mean to approximate the + variance. In order to infer the mean, a stochastic estimate of the + Kullback-Leibler divergence is minimized. This estimate is obtained by + sampling the Metric Gaussian at the current mean. 
During minimization these + samples are kept constant and only the mean is updated. Due to the + typically nonlinear structure of the true distribution these samples have + to be updated eventually by re-instantiating the Metric Gaussian again. For + the true probability distribution the standard parametrization is assumed. + + Parameters + ---------- + + hamiltonian : :class:`nifty8.src.re.likelihood.StandardHamiltonian` + Hamiltonian of the approximated probability distribution. + primals : :class:`nifty8.re.field.Field` + Expansion point of the coordinate transformation. + n_samples : integer + Number of samples used to stochastically estimate the KL. + key : DeviceArray + A PRNG-key. + mirror_samples : boolean + Whether the mirrored version of the drawn samples are also used. + If true, the number of used samples doubles. + Mirroring samples stabilizes the KL estimate as extreme + sample variation is counterbalanced. + Default is True. + sample_mapping : string, callable + Can be either a string-key to a mapping function or a mapping function + itself. The function is used to map the drawing of samples. Possible + string-keys are: + + keys - functions + ------------------------------------- + 'pmap' or 'p' - jax.pmap + 'lax.map' or 'lax' - jax.lax.map + + In case sample_mapping is passed as a function, it should produce a + mapped function f_mapped of a general function f as: `f_mapped = + sample_mapping(f)` + linear_sampling_cg : callable + Implementation of the conjugate gradient algorithm and used to + apply the inverse of the metric. + linear_sampling_name : string, optional + 'name'-keyword-argument passed to `linear_sampling_cg`. + linear_sampling_kwargs : dict, optional + Additional keyword arguments passed on to `linear_sampling_cg`. + + See also + -------- + `Metric Gaussian Variational Inference`, Jakob Knollmüller, + Torsten A. 
Enßlin, `<https://arxiv.org/abs/1901.11033>`_ + """ + if not isinstance(hamiltonian, StandardHamiltonian): + te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'" + raise TypeError(te) + assert_arithmetics(primals) + + draw = partial( + sample_standard_hamiltonian, + hamiltonian=hamiltonian, + primals=primals, + cg=linear_sampling_cg, + cg_name=linear_sampling_name, + cg_kwargs=linear_sampling_kwargs + ) + subkeys = random.split(key, n_samples) + if isinstance(sample_mapping, str): + if sample_mapping == 'pmap' or sample_mapping == 'p': + sample_mapping = jax.pmap + elif sample_mapping == 'lax.map' or sample_mapping == 'lax': + sample_mapping = partial(partial, lax.map) + else: + ve = ( + f"{sample_mapping} is not an accepted key to a mapping function" + "; please pass function directly" + ) + raise ValueError(ve) + + elif not callable(sample_mapping): + te = ( + f"invalid `sample_mapping` of type {type(sample_mapping)!r}" + "; expected string or callable" + ) + raise TypeError(te) + + samples_stack = sample_mapping(lambda k: draw(key=k))(subkeys) + + return SampleIter( + mean=primals, + samples=unstack(samples_stack), + linearly_mirror_samples=mirror_samples + ) + + +def GeoMetricKL( + hamiltonian: StandardHamiltonian, + primals, + n_samples: int, + key, + mirror_samples: bool = True, + linear_sampling_cg: Callable = conjugate_gradient.static_cg, + linear_sampling_name: Optional[str] = None, + linear_sampling_kwargs: Optional[dict] = None, + non_linear_sampling_method: str = "NewtonCG", + non_linear_sampling_name: Optional[str] = None, + non_linear_sampling_kwargs: Optional[dict] = None, +) -> SampleIter: + """Provides the sampled Kullback-Leibler used in geometric Variational + Inference (geoVI). + + In geoVI a probability distribution is approximated with a standard normal + distribution in the canonical coordinate system of the Riemannian manifold + associated with the metric of the other distribution. The coordinate + transformation is approximated by expanding around a point. In order to + infer the expansion point, a stochastic estimate of the Kullback-Leibler + divergence is minimized. This estimate is obtained by sampling from the + approximation using the current expansion point. During minimization these + samples are kept constant and only the expansion point is updated. Due to + the typically nonlinear structure of the true distribution these samples + have to be updated eventually by re-instantiating the geometric Gaussian + again. For the true probability distribution the standard parametrization + is assumed. + + See also + -------- + `Geometric Variational Inference`, Philipp Frank, Reimar Leike, + Torsten A. 
Enßlin, `<https://arxiv.org/abs/2105.10470>`_ + `<https://doi.org/10.3390/e23070853>`_ + """ + if not isinstance(hamiltonian, StandardHamiltonian): + te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'" + raise TypeError(te) + assert_arithmetics(primals) + + draw = partial( + geometrically_sample_standard_hamiltonian, + hamiltonian=hamiltonian, + primals=primals, + mirror_linear_sample=mirror_samples, + linear_sampling_cg=linear_sampling_cg, + linear_sampling_name=linear_sampling_name, + linear_sampling_kwargs=linear_sampling_kwargs, + non_linear_sampling_method=non_linear_sampling_method, + non_linear_sampling_name=non_linear_sampling_name, + non_linear_sampling_kwargs=non_linear_sampling_kwargs + ) + subkeys = random.split(key, n_samples) + # TODO: Make `geometrically_sample_standard_hamiltonian` jit-able + # samples_stack = lax.map(lambda k: draw(key=k), subkeys) + # Unpack tuple of samples + # samples_stack = tree_map( + # lambda a: a.reshape((-1, ) + a.shape[2:]), samples_stack + # ) + # samples = unstack(samples_stack) + samples = tuple(s for ss in map(lambda k: draw(key=k), subkeys) for s in ss) + + return SampleIter( + mean=primals, samples=samples, linearly_mirror_samples=False + ) + + +def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs): + """Thin wrapper around `value_and_grad` and the provided sample mapping + function, e.g. `vmap` to apply a cost function to a mean and a list of + residual samples. + + Parameters + ---------- + + ham : :class:`nifty8.src.re.likelihood.StandardHamiltonian` + Hamiltonian of the approximated probability distribution, + of which the mean value and the mean gradient are to be computed. + sample_mapping : string, callable + Can be either a string-key to a mapping function or a mapping function + itself. The function is used to map the drawing of samples. Possible + string-keys are: + + keys - functions + ------------------------------------- + 'vmap' or 'v' - jax.vmap + 'pmap' or 'p' - jax.pmap + 'lax.map' or 'lax' - jax.lax.map + + In case sample_mapping is passed as a function, it should produce a + mapped function f_mapped of a general function f as: `f_mapped = + sample_mapping(f)` + """ + from jax import value_and_grad + vg = value_and_grad(ham, *args, **kwargs) + + def mean_vg( + primals: P, + primals_samples: Union[None, Sequence[P], SampleIter] = None, + **primals_kw + ) -> Tuple[Any, P]: + ham_vg = partial(vg, **primals_kw) + if primals_samples is None: + return ham_vg(primals) + + if not isinstance(primals_samples, SampleIter): + primals_samples = SampleIter(samples=primals_samples) + return map_forest_mean(ham_vg, mapping=sample_mapping, in_axes=(0, ))( + tuple(primals_samples.at(primals)) + ) + + return mean_vg + + +def mean_hessp(ham: Callable, *args, **kwargs): + """Thin wrapper around `jvp`, `grad` and `vmap` to apply a binary method to + a primal mean, a tangent and a list of residual primal samples. 
+ """ + from jax import jvp, grad + jac = grad(ham, *args, **kwargs) + + def mean_hp( + primals: P, + tangents: Any, + primals_samples: Union[None, Sequence[P], SampleIter] = None, + **primals_kw + ) -> P: + if primals_samples is None: + _, hp = jvp(partial(jac, **primals_kw), (primals, ), (tangents, )) + return hp + + if not isinstance(primals_samples, SampleIter): + primals_samples = SampleIter(samples=primals_samples) + return map_forest_mean( + partial(mean_hp, primals_samples=None, **primals_kw), + in_axes=(0, None) + )(tuple(primals_samples.at(primals)), tangents) + + return mean_hp + + +def mean_metric(metric: Callable): + """Thin wrapper around `vmap` to apply a binary method to a primal mean, a + tangent and a list of residual primal samples. + """ + def mean_met( + primals: P, + tangents: Any, + primals_samples: Union[None, Sequence[P], SampleIter] = None, + **primals_kw + ) -> P: + if primals_samples is None: + return metric(primals, tangents, **primals_kw) + + if not isinstance(primals_samples, SampleIter): + primals_samples = SampleIter(samples=primals_samples) + return map_forest_mean( + partial(metric, **primals_kw), in_axes=(0, None) + )(tuple(primals_samples.at(primals)), tangents) + + return mean_met diff --git a/src/re/lanczos.py b/src/re/lanczos.py new file mode 100644 index 0000000000000000000000000000000000000000..f684246fe3a529166a4036415325394e6762898b --- /dev/null +++ b/src/re/lanczos.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +from typing import Callable, Optional, Union + +import jax +from jax import numpy as jnp +from jax import random + +from .forest_util import ShapeWithDtype + + +def lanczos_tridiag( + mat: Callable, shape_dtype_struct: ShapeWithDtype, order: int, + key: jnp.ndarray +): + """Compute the Lanczos decomposition into a tri-diagonal matrix and its + corresponding orthonormal projection matrix. + """ + tridiag = jnp.zeros((order, order), dtype=shape_dtype_struct.dtype) + vecs = jnp.zeros( + (order, ) + shape_dtype_struct.shape, dtype=shape_dtype_struct.dtype + ) + + v = random.normal(key, shape=shape_dtype_struct.shape) + v = v / jnp.linalg.norm(v) + vecs = vecs.at[0].set(v) + + # Zeroth iteration + w = mat(v) + if w.shape != shape_dtype_struct.shape: + ve = f"shape of `mat(v)` {w.shape!r} incompatible with {shape_dtype_struct}" + raise ValueError(ve) + alpha = jnp.dot(w, v) + tridiag = tridiag.at[(0, 0)].set(alpha) + w -= alpha * v + beta = jnp.linalg.norm(w) + + tridiag = tridiag.at[(0, 1)].set(beta) + tridiag = tridiag.at[(1, 0)].set(beta) + vecs = vecs.at[1].set(w / beta) + + def reortho_step(j, state): + vecs, w = state + + tau = vecs[j, :].reshape(shape_dtype_struct.shape) + coeff = jnp.dot(w, tau) + w -= coeff * tau + return vecs, w + + def lanczos_step(i, state): + tridiag, vecs, beta = state + + v = vecs[i, :].reshape(shape_dtype_struct.shape) + v_old = vecs[i - 1, :].reshape(shape_dtype_struct.shape) + + w = mat(v) - beta * v_old + alpha = jnp.dot(w, v) + tridiag = tridiag.at[(i, i)].set(alpha) + w -= alpha * v + + # Full reorthogonalization + vecs, w = jax.lax.fori_loop(0, i, reortho_step, (vecs, w)) + + # TODO: Raise if lanczos vectors are independent i.e. `beta` small? 
+ beta = jnp.linalg.norm(w)
+
+ tridiag = tridiag.at[(i, i + 1)].set(beta)
+ tridiag = tridiag.at[(i + 1, i)].set(beta)
+ vecs = vecs.at[i + 1].set(w / beta)
+
+ return tridiag, vecs, beta
+
+ tridiag, vecs, beta = jax.lax.fori_loop(
+ 1, order - 1, lanczos_step, (tridiag, vecs, beta)
+ )
+
+ # Final tridiag value and reorthogonalization
+ v = vecs[order - 1, :].reshape(shape_dtype_struct.shape)
+ v_old = vecs[order - 2, :].reshape(shape_dtype_struct.shape)
+ w = mat(v) - beta * v_old
+ alpha = jnp.dot(w, v)
+ tridiag = tridiag.at[(order - 1, order - 1)].set(alpha)
+ w -= alpha * v
+ vecs, w = jax.lax.fori_loop(0, order - 1, reortho_step, (vecs, w))
+
+ return (tridiag, vecs)
+
+
+def stochastic_logdet_from_lanczos(
+ tridiag_stack: jnp.ndarray, matrix_shape0: int, func: Callable = jnp.log
+):
+ """Computes a stochastic estimate of the log-determinant of a matrix using
+ its Lanczos decomposition.
+
+ Implemented via the stochastic Lanczos quadrature.
+ """
+ eig_vals, eig_vecs = jnp.linalg.eigh(tridiag_stack)
+ # TODO: Mask eigenvalues <= 0?
+
+ num_random_probes = tridiag_stack.shape[0]
+
+ eig_vecs_first_component = eig_vecs[..., 0, :]
+ func_of_eig_vals = func(eig_vals)
+
+ dot_products = jnp.sum(eig_vecs_first_component**2 * func_of_eig_vals)
+ return matrix_shape0 / float(num_random_probes) * dot_products
+
+
+def stochastic_lq_logdet(
+ mat: Union[jnp.ndarray, Callable],
+ order: int,
+ n_samples: int,
+ key: Union[int, jnp.ndarray],
+ *,
+ shape0: Optional[int] = None,
+ dtype=None
+):
+ """Computes a stochastic estimate of the log-determinant of a matrix using
+ the stochastic Lanczos quadrature algorithm.
+ """
+ shape0 = shape0 if shape0 is not None else mat.shape[0]
+ mat = mat.__matmul__ if not hasattr(mat, "__call__") else mat
+ if not isinstance(key, jnp.ndarray):
+ key = random.PRNGKey(key)
+ keys = random.split(key, n_samples)
+
+ lanczos = partial(lanczos_tridiag, mat, ShapeWithDtype(shape0, dtype))
+ tridiags, _ = jax.vmap(lanczos, in_axes=(None, 0),
+ out_axes=(0, 0))(order, keys)
+ return stochastic_logdet_from_lanczos(tridiags, shape0)
diff --git a/src/re/likelihood.py b/src/re/likelihood.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a9621e33486dde2b2060126f3f5f359da0b25c
--- /dev/null
+++ b/src/re/likelihood.py
@@ -0,0 +1,390 @@
+# Copyright(C) 2013-2021 Max-Planck-Society
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from typing import Any, Callable, Optional, TypeVar, Union
+
+from jax import numpy as jnp
+from jax import linear_transpose, linearize, vjp
+from jax.tree_util import Partial, tree_leaves
+
+from .forest_util import ShapeWithDtype, split
+from .sugar import is1d, isiterable, sum_of_squares, doc_from
+
+Q = TypeVar("Q")
+
+
+class Likelihood():
+ """Storage class for keeping track of the energy, the associated
+ left-square-root of the metric and the metric.
+ """
+ def __init__(
+ self,
+ energy: Callable[..., Union[jnp.ndarray, float]],
+ transformation: Optional[Callable[[Q], Any]] = None,
+ left_sqrt_metric: Optional[Callable[[Q, Q], Any]] = None,
+ metric: Optional[Callable[[Q, Q], Any]] = None,
+ lsm_tangents_shape=None
+ ):
+ """Instantiates a new likelihood.
+
+ Parameters
+ ----------
+ energy : callable
+ Function evaluating the negative log-likelihood.
+ transformation : callable, optional
+ Function evaluating the geometric transformation of the likelihood.
+ left_sqrt_metric : callable, optional
+ Function applying the left-square-root of the metric.
+ metric : callable, optional + Function applying the metric. + lsm_tangents_shape : tree-like structure of ShapeWithDtype, optional + Structure of the data space. + """ + self._hamiltonian = energy + self._transformation = transformation + self._left_sqrt_metric = left_sqrt_metric + self._metric = metric + + if lsm_tangents_shape is not None: + leaves = tree_leaves(lsm_tangents_shape) + if not all( + hasattr(e, "shape") and hasattr(e, "dtype") for e in leaves + ): + if is1d(lsm_tangents_shape + ) or not isiterable(lsm_tangents_shape): + lsm_tangents_shape = ShapeWithDtype(lsm_tangents_shape) + else: + te = "`lsm_tangent_shapes` of invalid type" + raise TypeError(te) + self._lsm_tan_shp = lsm_tangents_shape + + def __call__(self, primals, **primals_kw): + """Convenience method to access the `energy` method of this instance. + """ + return self.energy(primals, **primals_kw) + + def energy(self, primals, **primals_kw): + """Applies the energy to `primals`. + + Parameters + ---------- + primals : tree-like structure + Position at which to evaluate the energy. + **primals_kw : Any + Additional arguments passed on to the energy. + + Returns + ------- + energy : float + Energy at the position `primals`. + """ + return self._hamiltonian(primals, **primals_kw) + + def metric(self, primals, tangents, **primals_kw): + """Applies the metric at `primals` to `tangents`. + + Parameters + ---------- + primals : tree-like structure + Position at which to evaluate the metric. + tangents : tree-like structure + Instance to which to apply the metric. + **primals_kw : Any + Additional arguments passed on to the metric. + + Returns + ------- + naturally_curved : tree-like structure + Tree-like structure of the same type as primals to which the metric + has been applied to. + """ + if self._metric is None: + from jax import linear_transpose + + lsm_at_p = Partial(self.left_sqrt_metric, primals, **primals_kw) + rsm_at_p = linear_transpose( + lsm_at_p, self.left_sqrt_metric_tangents_shape + ) + res = lsm_at_p(*rsm_at_p(tangents)) + return res + return self._metric(primals, tangents, **primals_kw) + + def left_sqrt_metric(self, primals, tangents, **primals_kw): + """Applies the left-square-root of the metric at `primals` to + `tangents`. + + Parameters + ---------- + primals : tree-like structure + Position at which to evaluate the metric. + tangents : tree-like structure + Instance to which to apply the metric. + **primals_kw : Any + Additional arguments passed on to the LSM. + + Returns + ------- + metric_sample : tree-like structure + Tree-like structure of the same type as primals to which the + left-square-root of the metric has been applied to. + """ + if self._left_sqrt_metric is None: + _, bwd = vjp(Partial(self.transformation, **primals_kw), primals) + res = bwd(tangents) + return res[0] + return self._left_sqrt_metric(primals, tangents, **primals_kw) + + def transformation(self, primals, **primals_kw): + """Applies the coordinate transformation that maps into a coordinate + system in which the metric of the likelihood is the Euclidean metric. + + Parameters + ---------- + primals : tree-like structure + Position at which to transform. + **primals_kw : Any + Additional arguments passed on to the transformation. + + Returns + ------- + transformed_sample : tree-like structure + Structure of the same type as primals to which the geometric + transformation has been applied to. 
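+
+ Notes
+ -----
+ In terms of the Jacobian :math:`J` of this transformation, the metric
+ of the likelihood is the pull-back of the Euclidean metric, i.e.
+ :math:`M = J^\\dagger J`. This is how `left_sqrt_metric` and `metric`
+ are derived from the transformation whenever they are not specified
+ explicitly.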
+ """ + if self._transformation is None: + nie = "`transformation` is not implemented" + raise NotImplementedError(nie) + return self._transformation(primals, **primals_kw) + + @property + def left_sqrt_metric_tangents_shape(self): + """Retrieves the shape of the tangent domain of the + left-square-root of the metric. + """ + return self._lsm_tan_shp + + @property + def lsm_tangents_shape(self): + """Alias for `left_sqrt_metric_tangents_shape`.""" + return self.left_sqrt_metric_tangents_shape + + def new( + self, energy: Callable, transformation: Optional[Callable], + left_sqrt_metric: Optional[Callable], metric: Optional[Callable] + ): + """Instantiates a new likelihood with the same `lsm_tangents_shape`. + + Parameters + ---------- + energy : callable + Function evaluating the negative log-likelihood. + transformation : callable, optional + Function evaluating the geometric transformation of the + log-likelihood. + left_sqrt_metric : callable, optional + Function applying the left-square-root of the metric. + metric : callable, optional + Function applying the metric. + """ + return Likelihood( + energy, + transformation=transformation, + left_sqrt_metric=left_sqrt_metric, + metric=metric, + lsm_tangents_shape=self._lsm_tan_shp + ) + + def jit(self, **kwargs): + """Returns a new likelihood with jit-compiled energy, left-square-root + of metric and metric. + """ + from jax import jit + + if self._transformation is not None: + j_trafo = jit(self.transformation, **kwargs) + j_lsm = jit(self.left_sqrt_metric, **kwargs) + j_m = jit(self.metric, **kwargs) + elif self._left_sqrt_metric is not None: + j_trafo = None + j_lsm = jit(self.left_sqrt_metric, **kwargs) + j_m = jit(self.metric, **kwargs) + elif self._metric is not None: + j_trafo, j_lsm = None, None + j_m = jit(self.metric, **kwargs) + else: + j_trafo, j_lsm, j_m = None, None, None + + return self.new( + jit(self._hamiltonian, **kwargs), + transformation=j_trafo, + left_sqrt_metric=j_lsm, + metric=j_m + ) + + def __matmul__(self, f: Callable): + return self.matmul(f, left_argnames=(), right_argnames=None) + + def matmul(self, f: Callable, left_argnames=(), right_argnames=None): + """Amend the function `f` to the right of the likelihood. + + Parameters + ---------- + f : Callable + Function which to amend to the likelihood. + left_argnames : tuple or None + Keys of the keyword arguments of the joined likelihood which + to pass to the original likelihood. Passing `None` indicates + the intent to absorb everything not explicitly absorbed by + the other call. + right_argnames : tuple or None + Keys of the keyword arguments of the joined likelihood which + to pass to the amended function. Passing `None` indicates + the intent to absorb everything not explicitly absorbed by + the other call. 
+ + Returns + ------- + lh : Likelihood + """ + if (left_argnames is None and right_argnames is None) or \ + (left_argnames is not None and right_argnames is not None): + ve = "only one of `left_argnames` and `right_argnames` can be (not) `None`" + raise ValueError(ve) + + def split_kwargs(**kwargs): + if left_argnames is None: # right_argnames must be not None + right_kw, left_kw = split(kwargs, right_argnames) + else: # right_argnames must be None + left_kw, right_kw = split(kwargs, left_argnames) + return left_kw, right_kw + + def energy_at_f(primals, **primals_kw): + kw_l, kw_r = split_kwargs(**primals_kw) + return self.energy(f(primals, **kw_r), **kw_l) + + def transformation_at_f(primals, **primals_kw): + kw_l, kw_r = split_kwargs(**primals_kw) + return self.transformation(f(primals, **kw_r), **kw_l) + + def metric_at_f(primals, tangents, **primals_kw): + kw_l, kw_r = split_kwargs(**primals_kw) + # Note, judging by a simple benchmark on a large problem, + # transposing the JVP seems faster than computing the VJP again. On + # small problems there seems to be no measurable difference. + y, fwd = linearize(Partial(f, **kw_r), primals) + bwd = linear_transpose(fwd, primals) + return bwd(self.metric(y, fwd(tangents), **kw_l))[0] + + def left_sqrt_metric_at_f(primals, tangents, **primals_kw): + kw_l, kw_r = split_kwargs(**primals_kw) + y, bwd = vjp(Partial(f, **kw_r), primals) + left_at_fp = self.left_sqrt_metric(y, tangents, **kw_l) + return bwd(left_at_fp)[0] + + return self.new( + energy_at_f, + transformation=transformation_at_f, + left_sqrt_metric=left_sqrt_metric_at_f, + metric=metric_at_f + ) + + def __add__(self, other): + if not isinstance(other, Likelihood): + te = ( + "object which to add to this instance is of invalid type" + f" {type(other)!r}" + ) + raise TypeError(te) + + def joined_hamiltonian(p, **pkw): + return self.energy(p, **pkw) + other.energy(p, **pkw) + + def joined_metric(p, t, **pkw): + return self.metric(p, t, **pkw) + other.metric(p, t, **pkw) + + joined_tangents_shape = { + "lh_left": self._lsm_tan_shp, + "lh_right": other._lsm_tan_shp + } + + def joined_transformation(p, **pkw): + from warnings import warn + + # FIXME + warn("adding transformations is untested", UserWarning) + return { + "lh_left": self.transformation(p, **pkw), + "lh_right": other.transformation(p, **pkw) + } + + def joined_left_sqrt_metric(p, t, **pkw): + return self.left_sqrt_metric( + p, t["lh_left"], **pkw + ) + other.left_sqrt_metric(p, t["lh_right"], **pkw) + + return Likelihood( + joined_hamiltonian, + transformation=joined_transformation, + left_sqrt_metric=joined_left_sqrt_metric, + metric=joined_metric, + lsm_tangents_shape=joined_tangents_shape + ) + + +class StandardHamiltonian(): + """Joined object storage composed of a user-defined likelihood and a + standard normal likelihood as prior. + """ + def __init__( + self, + likelihood: Likelihood, + _compile_joined: bool = False, + _compile_kwargs: dict = {} + ): + """Instantiates a new standardized Hamiltonian, i.e. a likelihood + joined with a standard normal prior. + + Parameters + ---------- + likelihood : Likelihood + Energy, left-square-root of metric and metric of the likelihood. 
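+
+ Notes
+ -----
+ Up to an additive constant, the energy of the joined model reads
+
+ .. math::
+ H(\\xi) = \\ell(\\xi) + \\frac{1}{2} \\xi^\\dagger \\xi ,
+
+ i.e. the energy of the likelihood :math:`\\ell` plus a standard
+ normal prior on the parameters :math:`\\xi`. Its metric is
+ accordingly the metric of the likelihood plus the identity.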
+ """ + self._lh = likelihood + + def joined_hamiltonian(primals, **primals_kw): + # Assume the first primals to be the parameters + return self._lh(primals, ** + primals_kw) + 0.5 * sum_of_squares(primals) + + def joined_metric(primals, tangents, **primals_kw): + return self._lh.metric(primals, tangents, **primals_kw) + tangents + + if _compile_joined: + from jax import jit + joined_hamiltonian = jit(joined_hamiltonian, **_compile_kwargs) + joined_metric = jit(joined_metric, **_compile_kwargs) + self._hamiltonian = joined_hamiltonian + self._metric = joined_metric + + @doc_from(Likelihood.__call__) + def __call__(self, primals, **primals_kw): + return self.energy(primals, **primals_kw) + + @doc_from(Likelihood.energy) + def energy(self, primals, **primals_kw): + return self._hamiltonian(primals, **primals_kw) + + @doc_from(Likelihood.metric) + def metric(self, primals, tangents, **primals_kw): + return self._metric(primals, tangents, **primals_kw) + + @property + def likelihood(self): + return self._lh + + def jit(self, **kwargs): + return StandardHamiltonian( + self.likelihood.jit(**kwargs), + _compile_joined=True, + _compile_kwargs=kwargs + ) diff --git a/src/re/optimize.py b/src/re/optimize.py new file mode 100644 index 0000000000000000000000000000000000000000..6699cf5da6c4df8cd42775bb628659466f2313af --- /dev/null +++ b/src/re/optimize.py @@ -0,0 +1,481 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +import sys +from datetime import datetime +from typing import Any, Callable, Dict, Mapping, NamedTuple, Optional, Tuple, Union + +from jax import lax +from jax import numpy as jnp +from jax.tree_util import Partial + +from . import conjugate_gradient +from .forest_util import assert_arithmetics, common_type, size, where +from .forest_util import norm as jft_norm +from .sugar import sum_of_squares + + +class OptimizeResults(NamedTuple): + """Object holding optimization results inspired by JAX and scipy. + + Attributes + ---------- + x : ndarray + The solution of the optimization. + success : bool + Whether or not the optimizer exited successfully. + status : int + Termination status of the optimizer. Its value depends on the + underlying solver. NOTE, in contrast to scipy there is no `message` for + details since strings are not statically memory bound. + fun, jac, hess: ndarray + Values of objective function, its Jacobian and its Hessian (if + available). The Hessians may be approximations, see the documentation + of the function in question. + hess_inv : object + Inverse of the objective function's Hessian; may be an approximation. + Not available for all solvers. + nfev, njev, nhev : int + Number of evaluations of the objective functions and of its + Jacobian and Hessian. + nit : int + Number of iterations performed by the optimizer. 
+ """ + x: Any + success: Union[bool, jnp.ndarray] + status: Union[int, jnp.ndarray] + fun: Any + jac: Any + hess: Optional[jnp.ndarray] = None + hess_inv: Optional[jnp.ndarray] = None + nfev: Union[None, int, jnp.ndarray] = None + njev: Union[None, int, jnp.ndarray] = None + nhev: Union[None, int, jnp.ndarray] = None + nit: Union[None, int, jnp.ndarray] = None + # Trust-Region specific slots + trust_radius: Union[None, float, jnp.ndarray] = None + jac_magnitude: Union[None, float, jnp.ndarray] = None + good_approximation: Union[None, bool, jnp.ndarray] = None + + +def _prepare_vag_hessp(fun, jac, hessp, + fun_and_grad) -> Tuple[Callable, Callable]: + """Returns a tuple of functions for computing the value-and-gradient and + the Hessian-Vector-Product. + """ + from warnings import warn + + if fun_and_grad is None: + if fun is not None and jac is not None: + uw = "computing the function together with its gradient would be faster" + warn(uw, UserWarning) + + def fun_and_grad(x): + return (fun(x), jac(x)) + elif fun is not None: + from jax import value_and_grad + + fun_and_grad = value_and_grad(fun) + else: + ValueError("no function specified") + + if hessp is None: + from jax import grad, jvp + + jac = grad(fun) if jac is None else jac + + def hessp(primals, tangents): + return jvp(jac, (primals, ), (tangents, ))[1] + + return fun_and_grad, hessp + + +def newton_cg(fun=None, x0=None, *args, **kwargs): + """Minimize a scalar-valued function using the Newton-CG algorithm.""" + if x0 is not None: + assert_arithmetics(x0) + return _newton_cg(fun, x0, *args, **kwargs).x + + +def _newton_cg( + fun=None, + x0=None, + *, + miniter=None, + maxiter=None, + energy_reduction_factor=0.1, + old_fval=None, + absdelta=None, + norm_ord=None, + xtol=1e-5, + jac: Optional[Callable] = None, + fun_and_grad=None, + hessp=None, + cg=conjugate_gradient._cg, + name=None, + time_threshold=None, + cg_kwargs=None +): + norm_ord = 1 if norm_ord is None else norm_ord + miniter = 0 if miniter is None else miniter + maxiter = 200 if maxiter is None else maxiter + xtol = xtol * size(x0) + + pos = x0 + fun_and_grad, hessp = _prepare_vag_hessp( + fun, jac, hessp, fun_and_grad=fun_and_grad + ) + cg_kwargs = {} if cg_kwargs is None else cg_kwargs + cg_name = name + "CG" if name is not None else None + + energy, g = fun_and_grad(pos) + nfev, njev, nhev = 1, 1, 0 + if jnp.isnan(energy): + raise ValueError("energy is Nan") + status = -1 + i = 0 + for i in range(1, maxiter + 1): + # Newton approximates the potential up to second order. The CG energy + # (`0.5 * x.T @ A @ x - x.T @ b`) and the approximation to the true + # potential in Newton thus live on comparable energy scales. Hence, the + # energy in a Newton minimization can be used to set the CG energy + # convergence criterion. + if old_fval and energy_reduction_factor: + cg_absdelta = energy_reduction_factor * (old_fval - energy) + else: + cg_absdelta = None if absdelta is None else absdelta / 100. 
+ mag_g = jft_norm(g, ord=cg_kwargs.get("norm_ord", 1), ravel=True) + cg_resnorm = jnp.minimum( + 0.5, jnp.sqrt(mag_g) + ) * mag_g # taken from SciPy + default_kwargs = { + "absdelta": cg_absdelta, + "resnorm": cg_resnorm, + "norm_ord": 1, + "_within_newton": True, # handle non-pos-def + "name": cg_name, + "time_threshold": time_threshold + } + cg_res = cg(Partial(hessp, pos), g, **{**default_kwargs, ** cg_kwargs}) + nat_g, info = cg_res.x, cg_res.info + nhev += cg_res.nfev + if info is not None and info < 0: + raise ValueError("conjugate gradient failed") + + naive_ls_it = 0 + dd = nat_g # negative descent direction + grad_scaling = 1. + ls_reset = False + for naive_ls_it in range(9): + new_pos = pos - grad_scaling * dd + new_energy, new_g = fun_and_grad(new_pos) + nfev, njev = nfev + 1, njev + 1 + if new_energy <= energy: + break + + grad_scaling /= 2 + if naive_ls_it == 5: + ls_reset = True + gam = float(sum_of_squares(g)) + curv = float(g.dot(hessp(pos, g))) + nhev += 1 + grad_scaling = 1. + dd = gam / curv * g + else: + grad_scaling = 0. + nm = "N" if name is None else name + msg = f"{nm}: WARNING: Energy would increase; aborting" + print(msg, file=sys.stderr) + status = -1 + break + + energy_diff = energy - new_energy + old_fval = energy + energy = new_energy + pos = new_pos + g = new_g + + descent_norm = grad_scaling * jft_norm(dd, ord=norm_ord, ravel=True) + if name is not None: + msg = ( + f"{name}: →:{grad_scaling} ↺:{ls_reset} #∇²:{nhev:02d}" + f" |↘|:{descent_norm:.6e} 🞋:{xtol:.6e}" + f"\n{name}: Iteration {i} ⛰:{energy:+.6e} Δ⛰:{energy_diff:.6e}" + + (f" 🞋:{absdelta:.6e}" if absdelta is not None else "") + ) + print(msg, file=sys.stderr) + if jnp.isnan(new_energy): + raise ValueError("energy is NaN") + min_cond = naive_ls_it < 2 and i > miniter + if absdelta is not None and 0. 
<= energy_diff < absdelta and min_cond: + status = 0 + break + if descent_norm <= xtol and i > miniter: + status = 0 + break + if time_threshold is not None and datetime.now() > time_threshold: + status = i + break + else: + status = i + nm = "N" if name is None else name + print(f"{nm}: Iteration Limit Reached", file=sys.stderr) + return OptimizeResults( + x=pos, + success=True, + status=status, + fun=energy, + jac=g, + nit=i, + nfev=nfev, + njev=njev, + nhev=nhev + ) + + +class _TrustRegionState(NamedTuple): + x: Any + converged: Union[bool, jnp.ndarray] + status: Union[int, jnp.ndarray] + fun: Any + jac: Any + nfev: Union[int, jnp.ndarray] + njev: Union[int, jnp.ndarray] + nhev: Union[int, jnp.ndarray] + nit: Union[int, jnp.ndarray] + trust_radius: Union[float, jnp.ndarray] + jac_magnitude: Union[float, jnp.ndarray] + old_fval: Union[float, jnp.ndarray] + + +def _trust_ncg( + fun=None, + x0=None, + *, + maxiter: Optional[int] = None, + energy_reduction_factor=0.1, + old_fval=jnp.nan, + absdelta=None, + gtol: float = 1e-4, + max_trust_radius: Union[float, jnp.ndarray] = 1000., + initial_trust_radius: Union[float, jnp.ndarray] = 1.0, + eta: Union[float, jnp.ndarray] = 0.15, + subproblem=conjugate_gradient._cg_steihaug_subproblem, + jac: Optional[Callable] = None, + hessp: Optional[Callable] = None, + fun_and_grad: Optional[Callable] = None, + subproblem_kwargs: Optional[Dict[str, Any]] = None, + name: Optional[str] = None +) -> OptimizeResults: + maxiter = 200 if maxiter is None else maxiter + + status = jnp.where(maxiter == 0, 1, 0) + + if not (0 <= eta < 0.25): + raise Exception("invalid acceptance stringency") + # Exception("gradient tolerance must be positive") + status = jnp.where(gtol < 0., -1, status) + # Exception("max trust radius must be positive") + status = jnp.where(max_trust_radius <= 0, -1, status) + # ValueError("initial trust radius must be positive") + status = jnp.where(initial_trust_radius <= 0, -1, status) + # ValueError("initial trust radius must be less than the max trust radius") + status = jnp.where(initial_trust_radius >= max_trust_radius, -1, status) + + common_dtp = common_type(x0) + eps = 6. * jnp.finfo( + common_dtp + ).eps # Inspired by SciPy's NewtonCG minimzer + + fun_and_grad, hessp = _prepare_vag_hessp( + fun, jac, hessp, fun_and_grad=fun_and_grad + ) + subproblem_kwargs = {} if subproblem_kwargs is None else subproblem_kwargs + cg_name = name + "SP" if name is not None else None + + f_0, g_0 = fun_and_grad(x0) + g_0_mag = jft_norm( + g_0, ord=subproblem_kwargs.get("norm_ord", 1), ravel=True + ) + status = jnp.where(jnp.isfinite(g_0_mag), status, 2) + init_params = _TrustRegionState( + converged=False, + status=status, + nit=0, + x=x0, + fun=f_0, + jac=g_0, + jac_magnitude=g_0_mag, + nfev=1, + njev=1, + nhev=0, + trust_radius=initial_trust_radius, + old_fval=old_fval + ) + + def _trust_region_body_f(params: _TrustRegionState) -> _TrustRegionState: + x_k, g_k, g_k_mag = params.x, params.jac, params.jac_magnitude + i, f_k, old_fval = params.nit, params.fun, params.old_fval + tr = params.trust_radius + + i += 1 + + if energy_reduction_factor: + cg_absdelta = energy_reduction_factor * (old_fval - f_k) + else: + cg_absdelta = None if absdelta is None else absdelta / 100. 
+ cg_resnorm = jnp.minimum(0.5, jnp.sqrt(g_k_mag)) * g_k_mag + # TODO: add an internal success check for future subproblem approaches + # that might not be solvable + default_kwargs = { + "absdelta": cg_absdelta, + "resnorm": cg_resnorm, + "trust_radius": tr, + "norm_ord": 1, + "name": cg_name + } + sub_result = subproblem( + f_k, g_k, Partial(hessp, x_k), + **{**default_kwargs, **subproblem_kwargs} + ) + + pred_f_kp1 = sub_result.pred_f + x_kp1 = x_k + sub_result.step + f_kp1, g_kp1 = fun_and_grad(x_kp1) + + actual_reduction = f_k - f_kp1 + pred_reduction = f_k - pred_f_kp1 + + # update the trust radius according to the actual/predicted ratio + rho = actual_reduction / pred_reduction + tr_kp1 = jnp.where(rho < 0.25, tr * 0.25, tr) + tr_kp1 = jnp.where( + (rho > 0.75) & sub_result.hits_boundary, + jnp.minimum(2. * tr, max_trust_radius), tr_kp1 + ) + + # compute norm to check for convergence + g_kp1_mag = jft_norm( + g_kp1, ord=subproblem_kwargs.get("norm_ord", 1), ravel=True + ) + + # if the ratio is high enough then accept the proposed step + f_kp1, x_kp1, g_kp1, g_kp1_mag = where( + rho > eta, (f_kp1, x_kp1, g_kp1, g_kp1_mag), + (f_k, x_k, g_k, g_k_mag) + ) + + # Check whether we arrived at the float precision + energy_eps = eps * jnp.abs(f_kp1) + converged = (actual_reduction <= + energy_eps) & (actual_reduction > -energy_eps) + + converged |= g_kp1_mag < gtol + if absdelta: + converged |= (rho > eta) & (actual_reduction > + 0.) & (actual_reduction < absdelta) + + status = jnp.where(converged, 0, params.status) + status = jnp.where(i >= maxiter, 1, status) + status = jnp.where(pred_reduction <= 0, 2, status) + params = _TrustRegionState( + converged=converged, + nit=i, + x=x_kp1, + fun=f_kp1, + jac=g_kp1, + jac_magnitude=g_kp1_mag, + nfev=params.nfev + sub_result.nfev + 1, + njev=params.njev + sub_result.njev + 1, + nhev=params.nhev + sub_result.nhev, + trust_radius=tr_kp1, + status=status, + old_fval=f_k + ) + if name is not None: + from jax.experimental.host_callback import call + + def pp(arg): + i = arg["i"] + msg = ( + "{name}: ↗:{tr:.6e} ⬤:{hit} ∝:{rho:.2e} #∇²:{nhev:02d}" + "\n{name}: Iteration {i} ⛰:{energy:+.6e} Δ⛰:{energy_diff:.6e}" + + (" 🞋:{absdelta:.6e}" if absdelta is not None else "") + ( + "\n{name}: Iteration Limit Reached" + if i == maxiter else "" + ) + ) + print(msg.format(name=name, **arg), file=sys.stderr) + + printable_state = { + "i": i, + "energy": params.fun, + "energy_diff": actual_reduction, + "maxiter": maxiter, + "absdelta": absdelta, + "tr": params.trust_radius, + "rho": rho, + "nhev": params.nhev, + "hit": sub_result.hits_boundary + } + call(pp, printable_state, result_shape=None) + return params + + def _trust_region_cond_f(params: _TrustRegionState) -> bool: + return jnp.logical_not(params.converged) & (params.status == 0) + + state = lax.while_loop( + _trust_region_cond_f, _trust_region_body_f, init_params + ) + + return OptimizeResults( + success=state.converged & (state.status == 0), + nit=state.nit, + x=state.x, + fun=state.fun, + jac=state.jac, + nfev=state.nfev, + njev=state.njev, + nhev=state.nhev, + jac_magnitude=state.jac_magnitude, + trust_radius=state.trust_radius, + status=state.status + ) + + +def trust_ncg(fun=None, x0=None, *args, **kwargs): + if x0 is not None: + assert_arithmetics(x0) + return _trust_ncg(fun, x0, *args, **kwargs).x + + +def minimize( + fun: Optional[Callable[..., float]], + x0, + args: Tuple = (), + *, + method: str, + tol: Optional[float] = None, + options: Optional[Mapping[str, Any]] = None +) -> OptimizeResults: + 
"""Minimize fun.""" + assert_arithmetics(x0) + if options is None: + options = {} + if not isinstance(args, tuple): + te = f"args argument must be a tuple, got {type(args)!r}" + raise TypeError(te) + + fun_with_args = fun + if args: + fun_with_args = lambda x: fun(x, *args) + + if tol is not None: + raise ValueError("use solver-specific options") + + if method.lower() in ('newton-cg', 'newtoncg', 'ncg'): + return _newton_cg(fun_with_args, x0, **options) + elif method.lower() in ('trust-ncg', 'trustncg'): + return _trust_ncg(fun_with_args, x0, **options) + + raise ValueError(f"method {method} not recognized") diff --git a/src/re/refine.py b/src/re/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..b95572e5481f17e8ec05ef37ccc998b62b6fb923 --- /dev/null +++ b/src/re/refine.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +from math import ceil +from string import ascii_uppercase +from typing import Callable, Literal, Optional, Union + +from jax import vmap +from jax import numpy as jnp +from jax.lax import conv_general_dilated, dynamic_slice, fori_loop +import numpy as np + +NDARRAY = Union[jnp.ndarray, np.ndarray] +# N - batch dimension +# C - feature dimension of data (channel) +# I - input dimension of kernel +# O - output dimension of kernel +CONV_DIMENSION_NAMES = "".join(el for el in ascii_uppercase if el not in "NCIO") + + +def _assert(assertion): + if not assertion: + raise AssertionError() + + +def _get_cov_from_loc(kernel=None, + cov_from_loc=None + ) -> Callable[[NDARRAY, NDARRAY], NDARRAY]: + if cov_from_loc is None and callable(kernel): + + def cov_from_loc_sngl(x, y): + return kernel(jnp.linalg.norm(x - y)) + + cov_from_loc = vmap( + vmap(cov_from_loc_sngl, in_axes=(None, 0)), in_axes=(0, None) + ) + else: + if not callable(cov_from_loc): + ve = "exactly one of `cov_from_loc` or `kernel` must be set and callable" + raise ValueError(ve) + # TODO: benchmark whether using `triu_indices(n, k=1)` and + # `diag_indices(n)` is advantageous + return cov_from_loc + + +def layer_refinement_matrices( + distances, + kernel: Optional[Callable] = None, + cov_from_loc: Optional[Callable] = None, + *, + _coarse_size: int = 3, + _fine_size: int = 2, + _fine_strategy: Literal["jump", "extend"] = "jump", + _with_zeros: bool = False, +): + cov_from_loc = _get_cov_from_loc(kernel, cov_from_loc) + distances = jnp.asarray(distances) + # TODO: distances must be a tensor iff _coarse_size > 3 + # TODO: allow different grid sizes for different axis + csz = int(_coarse_size) # coarse size + if _coarse_size % 2 != 1: + raise ValueError("only odd numbers allowed for `_coarse_size`") + fsz = int(_fine_size) # fine size + if _fine_size % 2 != 0: + raise ValueError("only even numbers allowed for `_fine_size`") + + ndim = distances.size + csz_half = int((csz - 1) / 2) + gc = jnp.arange(-csz_half, csz_half + 1, dtype=float) + gc = distances.reshape(ndim, 1) * gc + gc = jnp.stack(jnp.meshgrid(*gc, indexing="ij"), axis=-1) + if _fine_strategy == "jump": + gf = jnp.arange(fsz, dtype=float) / fsz - 0.5 + 0.5 / fsz + elif _fine_strategy == "extend": + gf = jnp.arange(fsz, dtype=float) / 2 - 0.25 * (fsz - 1) + else: + raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}") + gf = distances.reshape(ndim, 1) * gf + gf = jnp.stack(jnp.meshgrid(*gf, indexing="ij"), axis=-1) + # On the GPU a single `cov_from_loc` call is about twice as fast as three + # separate calls for coarse-coarse, fine-fine and 
coarse-fine. + coord = jnp.concatenate( + (gc.reshape(-1, ndim), gf.reshape(-1, ndim)), axis=0 + ) + cov = cov_from_loc(coord, coord) + cov_ff = cov[-fsz**ndim:, -fsz**ndim:] + cov_fc = cov[-fsz**ndim:, :-fsz**ndim] + cov_cc = cov[:-fsz**ndim, :-fsz**ndim] + cov_cc_inv = jnp.linalg.inv(cov_cc) + + olf = cov_fc @ cov_cc_inv + # Also see Schur-Complement + if _with_zeros: + r = jnp.linalg.norm(gc.reshape(-1, ndim), axis=1) + r_cutoff = jnp.max(distances) * csz_half + # dampening is chosen somewhat arbitrarily + r_dampening = jnp.max(distances)**-ndim + olf_wgt_sphere = jnp.where( + r <= r_cutoff, 1., + jnp.exp(-r_dampening * jnp.abs(r - r_cutoff)**ndim) + ) + olf *= olf_wgt_sphere[jnp.newaxis, ...] + fine_kernel = cov_ff - olf @ cov_cc @ olf.T + else: + fine_kernel = cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T + # Implicitly assume a white power spectrum beyond the numerics limit. Use + # the diagonal as estimate for the magnitude of the variance. + fine_kernel_fallback = jnp.diag(jnp.abs(jnp.diag(fine_kernel))) + # Never produce NaNs (https://github.com/google/jax/issues/1052) + fine_kernel = jnp.where( + jnp.all(jnp.diag(fine_kernel) > 0.), fine_kernel, fine_kernel_fallback + ) + fine_kernel_sqrt = jnp.linalg.cholesky(fine_kernel) + + return olf, fine_kernel_sqrt + + +def refinement_matrices( + shape0, + depth, + distances0, + kernel: Optional[Callable] = None, + cov_from_loc: Optional[Callable] = None, + *, + _coarse_size: int = 3, + _fine_size: int = 2, + _fine_strategy: Literal["jump", "extend"] = "jump", + **kwargs, +): + cov_from_loc = _get_cov_from_loc(kernel, cov_from_loc) + + shape0 = np.atleast_1d(shape0) + distances0 = jnp.atleast_1d(distances0) + if shape0.shape != distances0.shape: + ve = ( + f"shape of `shape0` {shape0.shape} is incompatible with" + f" shape of `distances0` {distances0.shape}" + ) + raise ValueError(ve) + c0 = [d * jnp.arange(sz, dtype=float) for d, sz in zip(distances0, shape0)] + coord0 = jnp.stack(jnp.meshgrid(*c0, indexing="ij"), axis=-1) + coord0 = coord0.reshape(-1, len(shape0)) + cov_sqrt0 = jnp.linalg.cholesky(cov_from_loc(coord0, coord0)) + + if _fine_strategy == "jump": + dist_by_depth = distances0 / _fine_size**jnp.arange(0, depth + ).reshape(-1, 1) + elif _fine_strategy == "extend": + dist_by_depth = distances0 / 2**jnp.arange(0, depth).reshape(-1, 1) + else: + raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}") + olaf = partial( + layer_refinement_matrices, + cov_from_loc=cov_from_loc, + _coarse_size=_coarse_size, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy, + **kwargs + ) + opt_lin_filter, kernel_sqrt = vmap(olaf, in_axes=0, + out_axes=(0, 0))(dist_by_depth) + return opt_lin_filter, (cov_sqrt0, kernel_sqrt) + + +def _vmap_squeeze_first(fun, *args, **kwargs): + vfun = vmap(fun, *args, **kwargs) + + def vfun_apply(*x): + return vfun(jnp.squeeze(x[0], axis=0), *x[1:]) + + return vfun_apply + + +def refine_conv_general( + coarse_values, + excitations, + olf, + fine_kernel_sqrt, + precision=None, + _coarse_size: int = 3, + _fine_size: int = 2, + _fine_strategy: Literal["jump", "extend"] = "jump", +): + ndim = np.ndim(coarse_values) + # Introduce an artificial channel dimension for the matrix product + # TODO: allow different grid sizes for different axis + csz = int(_coarse_size) # coarse size + if _coarse_size % 2 != 1: + raise ValueError("only odd numbers allowed for `_coarse_size`") + fsz = int(_fine_size) # fine size + if _fine_size % 2 != 0: + raise ValueError("only even numbers allowed for `_fine_size`") + if 
olf.shape[:-2] != fine_kernel_sqrt.shape[:-2]: + ve = ( + "incompatible optimal linear filter (`olf`) and `fine_kernel_sqrt` shapes" + f"; got {olf.shape} and {fine_kernel_sqrt.shape}" + ) + raise ValueError(ve) + if olf.ndim > 2: + irreg_shape = olf.shape[:-2] + elif olf.ndim == 2: + irreg_shape = (1, ) * ndim + else: + ve = f"invalid shape of optimal linear filter (`olf`); got {olf.shape}" + raise ValueError(ve) + olf = olf.reshape( + irreg_shape + (fsz**ndim, ) + (csz, ) * (ndim - 1) + (1, csz) + ) + fine_kernel_sqrt = fine_kernel_sqrt.reshape(irreg_shape + (fsz**ndim, ) * 2) + + if _fine_strategy == "jump": + window_strides = (1, ) * ndim + fine_init_shape = tuple(n - (csz - 1) + for n in coarse_values.shape) + (fsz**ndim, ) + fine_final_shape = tuple( + fsz * (n - (csz - 1)) for n in coarse_values.shape + ) + convolution_slices = list(range(csz)) + elif _fine_strategy == "extend": + window_strides = (fsz // 2, ) * ndim + fine_init_shape = tuple( + ceil((n - (csz - 1)) / (fsz // 2)) for n in coarse_values.shape + ) + (fsz**ndim, ) + fine_final_shape = tuple( + fsz * ceil((n - (csz - 1)) / (fsz // 2)) + for n in coarse_values.shape + ) + convolution_slices = list(range(0, csz * fsz // 2, fsz // 2)) + + if fsz // 2 > csz: + ve = "extrapolation is not allowed (use `fine_size / 2 <= coarse_size`)" + raise ValueError(ve) + else: + raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}") + + if ndim > len(CONV_DIMENSION_NAMES): + ve = f"convolution for {ndim} dimensions not yet implemented" + raise ValueError(ve) + dim_names = CONV_DIMENSION_NAMES[:ndim] + conv = partial( + conv_general_dilated, + window_strides=window_strides, + padding="valid", + # channel-last layout is most efficient for vision models (at least in + # PyTorch) + dimension_numbers=( + f"N{dim_names}C", f"O{dim_names}I", f"N{dim_names}C" + ), + precision=precision, + ) + + c_shp_n1 = coarse_values.shape[-1] + c_slc_shp = (1, ) + c_slc_shp += tuple( + c if i == 1 else csz + for i, c in zip(irreg_shape, coarse_values.shape[:-1]) + ) + c_slc_shp += (-1, csz) + + fine = jnp.zeros(fine_init_shape) + PLC = -1 << 31 # integer placeholder outside of the here encountered regimes + irreg_indices = jnp.stack( + jnp.meshgrid( + *[ + jnp.arange(sz) if sz != 1 else jnp.array([PLC]) + for sz in irreg_shape + ], + indexing="ij" + ), + axis=-1 + ) + + def single_refinement_step(i, fine: jnp.ndarray) -> jnp.ndarray: + irreg_idx = jnp.unravel_index(i, irreg_indices.shape[:-1]) + _assert( + len(irreg_shape) == len(irreg_indices[irreg_idx]) == + len(window_strides) + ) + fine_init_idx = tuple( + idx if sz != 1 else slice(None) + for sz, idx in zip(irreg_shape, irreg_indices[irreg_idx]) + ) + # Make JAX/XLA happy with `dynamic_slice` + coarse_idx = tuple( + (ws * idx, csz) if sz != 1 else (0, cend) + for ws, sz, idx, cend in zip( + window_strides, irreg_shape, irreg_indices[irreg_idx], + coarse_values.shape + ) + ) + coarse_idx_select = partial( + dynamic_slice, + start_indices=list(zip(*coarse_idx))[0], + slice_sizes=list(zip(*coarse_idx))[1] + ) + + olf_at_i = jnp.squeeze( + olf[fine_init_idx], + axis=tuple(range(sum(i == 1 for i in irreg_shape))) + ) + if irreg_shape[-1] == 1 and fine_init_shape[-1] != 1: + _assert(fine_init_idx[-1] == slice(None)) + # loop over conv channel offsets to apply the filter matrix in a convolution + for i_f, i_c in enumerate(convolution_slices): + c = conv( + coarse_idx_select(coarse_values)[..., i_c:c_shp_n1 - + (c_shp_n1 - i_c) % + csz].reshape(c_slc_shp), + olf_at_i + )[0] + c = jnp.squeeze( + c, + 
axis=tuple(a for a, i in enumerate(irreg_shape) if i != 1) + ) + toti = fine_init_idx[:-1] + (slice(i_f, None, csz), ) + fine = fine.at[toti].set(c) + else: + _assert( + not isinstance(fine_init_idx[-1], slice) and + fine_init_idx[-1].ndim == 0 + ) + c = conv( + coarse_idx_select(coarse_values).reshape(c_slc_shp), olf_at_i + )[0] + c = jnp.squeeze( + c, axis=tuple(a for a, i in enumerate(irreg_shape) if i != 1) + ) + fine = fine.at[fine_init_idx].set(c) + + return fine + + fine = fori_loop( + 0, np.prod(irreg_indices.shape[:-1]), single_refinement_step, fine + ) + + matmul = partial(jnp.matmul, precision=precision) + for i in irreg_shape[::-1]: + if i != 1: + matmul = vmap(matmul, in_axes=(0, 0)) + else: + matmul = _vmap_squeeze_first(matmul, in_axes=(None, 0)) + m = matmul(fine_kernel_sqrt, excitations.reshape(fine_init_shape)) + rm_axs = tuple( + ax for ax, i in enumerate(m.shape[len(irreg_shape):], len(irreg_shape)) + if i == 1 + ) + fine += jnp.squeeze(m, axis=rm_axs) + + fine = fine.reshape(fine.shape[:-1] + (fsz, ) * ndim) + ax_label = np.arange(2 * ndim) + ax_t = [e for els in zip(ax_label[:ndim], ax_label[ndim:]) for e in els] + fine = jnp.transpose(fine, axes=ax_t) + + return fine.reshape(fine_final_shape) + + +def refine_slice( + coarse_values, + excitations, + olf, + fine_kernel_sqrt, + precision=None, + _coarse_size: int = 3, + _fine_size: int = 2, + _fine_strategy: Literal["jump", "extend"] = "jump", +): + ndim = np.ndim(coarse_values) + csz = int(_coarse_size) # coarse size + if _coarse_size % 2 != 1: + raise ValueError("only odd numbers allowed for `_coarse_size`") + fsz = int(_fine_size) # fine size + if _fine_size % 2 != 0: + raise ValueError("only even numbers allowed for `_fine_size`") + + if olf.shape[:-2] != fine_kernel_sqrt.shape[:-2]: + ve = ( + "incompatible optimal linear filter (`olf`) and `fine_kernel_sqrt` shapes" + f"; got {olf.shape} and {fine_kernel_sqrt.shape}" + ) + raise ValueError(ve) + if olf.ndim > 2: + irreg_shape = olf.shape[:-2] + elif olf.ndim == 2: + irreg_shape = (1, ) * ndim + else: + ve = f"invalid shape of optimal linear filter (`olf`); got {olf.shape}" + raise ValueError(ve) + olf = olf.reshape(irreg_shape + (fsz**ndim, ) + (csz, ) * ndim) + fine_kernel_sqrt = fine_kernel_sqrt.reshape(irreg_shape + (fsz**ndim, ) * 2) + + if _fine_strategy == "jump": + window_strides = (1, ) * ndim + fine_init_shape = tuple(n - (csz - 1) + for n in coarse_values.shape) + (fsz**ndim, ) + fine_final_shape = tuple( + fsz * (n - (csz - 1)) for n in coarse_values.shape + ) + elif _fine_strategy == "extend": + window_strides = (fsz // 2, ) * ndim + fine_init_shape = tuple( + ceil((n - (csz - 1)) / (fsz // 2)) for n in coarse_values.shape + ) + (fsz**ndim, ) + fine_final_shape = tuple( + fsz * ceil((n - (csz - 1)) / (fsz // 2)) + for n in coarse_values.shape + ) + + if fsz // 2 > csz: + ve = "extrapolation is not allowed (use `fine_size / 2 <= coarse_size`)" + raise ValueError(ve) + else: + raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}") + + def matmul_with_window_into(x, y, idx): + return jnp.tensordot( + x, + dynamic_slice(y, idx, slice_sizes=(csz, ) * ndim), + axes=ndim, + precision=precision + ) + + filter_coarse = matmul_with_window_into + corr_fine = partial(jnp.matmul, precision=precision) + for i in irreg_shape[::-1]: + if i != 1: + filter_coarse = vmap(filter_coarse, in_axes=(0, None, 1)) + corr_fine = vmap(corr_fine, in_axes=(0, 0)) + else: + filter_coarse = _vmap_squeeze_first(filter_coarse, in_axes=(None, None, 1)) + corr_fine = 
_vmap_squeeze_first(corr_fine, in_axes=(None, 0)) + + cv_idx = np.mgrid[tuple( + slice(None, sz - csz + 1, ws) + for sz, ws in zip(coarse_values.shape, window_strides) + )] + fine = filter_coarse(olf, coarse_values, cv_idx) + + m = corr_fine(fine_kernel_sqrt, excitations.reshape(fine_init_shape)) + rm_axs = tuple( + ax for ax, i in enumerate(m.shape[len(irreg_shape):], len(irreg_shape)) + if i == 1 + ) + fine += jnp.squeeze(m, axis=rm_axs) + + fine = fine.reshape(fine.shape[:-1] + (fsz, ) * ndim) + ax_label = np.arange(2 * ndim) + ax_t = [e for els in zip(ax_label[:ndim], ax_label[ndim:]) for e in els] + fine = jnp.transpose(fine, axes=ax_t) + + return fine.reshape(fine_final_shape) + + +def refine_conv( + coarse_values, excitations, olf, fine_kernel_sqrt, precision=None +): + fine_m = vmap( + partial(jnp.convolve, mode="valid", precision=precision), + in_axes=(None, 0), + out_axes=0 + )(coarse_values, olf[::-1]) + fine_m = jnp.moveaxis(fine_m, (0, ), (1, )) + fine_std = vmap(jnp.matmul, in_axes=(None, 0))( + fine_kernel_sqrt, excitations.reshape(-1, fine_kernel_sqrt.shape[-1]) + ) + + return (fine_m + fine_std).ravel() + + +def refine_loop( + coarse_values, excitations, olf, fine_kernel_sqrt, precision=None +): + fine_m = [ + jnp.convolve(coarse_values, o, mode="valid", precision=precision) + for o in olf[::-1] + ] + fine_m = jnp.stack(fine_m, axis=1) + fine_std = vmap(jnp.matmul, in_axes=(None, 0))( + fine_kernel_sqrt, excitations.reshape(-1, fine_kernel_sqrt.shape[-1]) + ) + + return (fine_m + fine_std).ravel() + + +def refine_vmap( + coarse_values, excitations, olf, fine_kernel_sqrt, precision=None +): + sh0 = coarse_values.shape[0] + conv = vmap( + partial(jnp.matmul, precision=precision), in_axes=(None, 0), out_axes=0 + ) + fine_m = jnp.zeros((coarse_values.size - 2, 2)) + fine_m = fine_m.at[0::3].set( + conv(olf, coarse_values[:sh0 - sh0 % 3].reshape(-1, 3)) + ) + fine_m = fine_m.at[1::3].set( + conv(olf, coarse_values[1:sh0 - (sh0 - 1) % 3].reshape(-1, 3)) + ) + fine_m = fine_m.at[2::3].set( + conv(olf, coarse_values[2:sh0 - (sh0 - 2) % 3].reshape(-1, 3)) + ) + + fine_std = vmap(jnp.matmul, in_axes=(None, 0))( + fine_kernel_sqrt, excitations.reshape(-1, fine_kernel_sqrt.shape[-1]) + ) + + return (fine_m + fine_std).ravel() + + +refine = refine_slice diff --git a/src/re/refine_chart.py b/src/re/refine_chart.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fbe7f6a4fb879664514482d68cbc959771b485 --- /dev/null +++ b/src/re/refine_chart.py @@ -0,0 +1,916 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from collections import namedtuple +from functools import partial +from typing import Callable, Iterable, Literal, Optional, Tuple, Union + +from jax import numpy as jnp +from jax import vmap +import numpy as np + +from .refine import _get_cov_from_loc, refine +from .refine_util import ( + coarse2fine_distances, + coarse2fine_shape, + fine2coarse_distances, + fine2coarse_shape, + get_refinement_shapewithdtype, +) + +DEPTH_RANGE = (0, 32) +MAX_SIZE0 = 1024 + + +class CoordinateChart(): + def __init__( + self, + min_shape: Optional[Iterable[int]] = None, + depth: Optional[int] = None, + *, + shape0: Optional[Iterable[int]] = None, + _coarse_size: int = 5, + _fine_size: int = 4, + _fine_strategy: Literal["jump", "extend"] = "extend", + rg2cart: Optional[Callable[[ + Iterable, + ], Iterable]] = None, + cart2rg: Optional[Callable[[ + Iterable, + ], Iterable]] = None, + regular_axes: Optional[Union[Iterable[int], Tuple]] = None, + 
irregular_axes: Optional[Union[Iterable[int], Tuple]] = None,
+        distances: Optional[Union[Iterable[float], float]] = None,
+        distances0: Optional[Union[Iterable[float], float]] = None,
+    ):
+        """Initialize a refinement chart.
+
+        Parameters
+        ----------
+        min_shape :
+            Minimal extent in pixels along each axis at the final refinement
+            level.
+        depth :
+            Number of refinement iterations.
+        shape0 :
+            Alternative to `min_shape`; specifies the extent in pixels along
+            each axis at the zeroth refinement level.
+        _coarse_size :
+            Number of coarse pixels which to refine to `_fine_size` fine
+            pixels.
+        _fine_size :
+            Number of fine pixels which to refine from `_coarse_size` coarse
+            pixels.
+        _fine_strategy :
+            Whether to space fine pixels solely within the centermost coarse
+            pixel ("jump"), or whether to always space them out s.t. each fine
+            pixel takes up half the Euclidean volume of a coarse pixel
+            ("extend").
+        rg2cart :
+            Function to translate Euclidean points on a regular coordinate
+            system to the Cartesian coordinate system of the modeled points.
+        cart2rg :
+            Inverse of `rg2cart`.
+        regular_axes :
+            Informs the coordinate chart on symmetries within the Cartesian
+            coordinate system of the modeled points. If specified, refinement
+            matrices are broadcast as needed instead of recomputed.
+        irregular_axes :
+            Complement of `regular_axes`. Specifying either is sufficient.
+        distances :
+            Special case of a coordinate chart in which the regular grid
+            points are merely stretched or compressed. `distances` are used to
+            set the distance between points along every axis at the final
+            refinement level.
+        distances0 :
+            Same as `distances` except that `distances0` refers to the
+            distances along every axis at the zeroth refinement level.
+
+        Note
+        ----
+        The functions `rg2cart` and `cart2rg` are always w.r.t. the grid at
+        zero depth. In other words, it is straightforward to increase the
+        resolution of an existing chart by simply increasing its depth.
+        However, extending a grid spatially is more cumbersome and is best
+        done via `shape0`.
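+
+        A minimal, purely regular chart (illustrative values) may be
+        constructed via `CoordinateChart(min_shape=(128, 128), distances=1.)`,
+        i.e. a two-dimensional grid with unit distances between pixels at the
+        final refinement level.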
+ """ + if depth is None: + if min_shape is None: + raise ValueError("specify `min_shape` to infer `depth`") + if shape0 is not None or distances0 is not None: + ve = "can not infer `depth` with `shape0` or `distances0` set" + raise ValueError(ve) + for depth in range(*DEPTH_RANGE): + shape0 = fine2coarse_shape( + min_shape, + depth=depth, + ceil_sizes=True, + _coarse_size=_coarse_size, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + if np.prod(shape0, dtype=int) <= MAX_SIZE0: + break + else: + ve = f"unable to find suitable `depth`; please specify manually" + raise ValueError(ve) + if depth < 0: + raise ValueError(f"invalid `depth`; got {depth!r}") + self._depth = depth + + if shape0 is None and min_shape is not None: + shape0 = fine2coarse_shape( + min_shape, + depth, + ceil_sizes=True, + _coarse_size=_coarse_size, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + elif shape0 is None: + raise ValueError("either `shape0` or `min_shape` must be specified") + self._shape0 = (shape0, ) if isinstance(shape0, int) else tuple(shape0) + self._shape = coarse2fine_shape( + shape0, + depth, + _coarse_size=_coarse_size, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + + if _fine_strategy not in ("jump", "extend"): + ve = f"invalid `_fine_strategy`; got {_fine_strategy}" + raise ValueError(ve) + + self._shape_at = partial( + coarse2fine_shape, + self.shape0, + _coarse_size=_coarse_size, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + + self._coarse_size = int(_coarse_size) + self._fine_size = int(_fine_size) + self._fine_strategy = _fine_strategy + + # Derived attributes + self._ndim = len(self.shape) + self._size = np.prod(self.shape, dtype=int) + + if rg2cart is None and cart2rg is None: + if distances0 is None and distances is None: + distances = jnp.ones((self.ndim, )) + distances0 = fine2coarse_distances( + distances, + depth, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + elif distances0 is not None: + distances0 = jnp.broadcast_to( + jnp.atleast_1d(distances0), (self.ndim, ) + ) + distances = coarse2fine_distances( + distances0, + depth, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + else: + distances = jnp.broadcast_to( + jnp.atleast_1d(distances), (self.ndim, ) + ) + distances0 = fine2coarse_distances( + distances, + depth, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + + def _rg2cart(x): + x = jnp.asarray(x) + return x * distances0.reshape((-1, ) + (1, ) * (x.ndim - 1)) + + def _cart2rg(x): + x = jnp.asarray(x) + return x / distances0.reshape((-1, ) + (1, ) * (x.ndim - 1)) + + if regular_axes is None and irregular_axes is None: + regular_axes = tuple(range(self.ndim)) + self._rg2cart = _rg2cart + self._cart2rg = _cart2rg + elif rg2cart is not None and cart2rg is not None: + c0 = jnp.mgrid[tuple(slice(s) for s in self.shape0)] + if not all( + jnp.allclose(r, c) for r, c in zip(cart2rg(rg2cart(c0)), c0) + ): + raise ValueError("`cart2rg` is not the inverse of `rg2cart`") + self._rg2cart = rg2cart + self._cart2rg = cart2rg + distances = distances0 = None + else: + ve = "invalid combination of `cart2rg`, `rg2cart` and `distances`" + raise ValueError(ve) + self.distances = distances + self.distances0 = distances0 + + self.distances_at = partial( + coarse2fine_distances, + self.distances0, + _fine_size=_fine_size, + _fine_strategy=_fine_strategy + ) + + if regular_axes is None and irregular_axes is not None: + regular_axes = tuple(set(range(self.ndim)) - set(irregular_axes)) + elif 
regular_axes is not None and irregular_axes is None: + irregular_axes = tuple(set(range(self.ndim)) - set(regular_axes)) + elif regular_axes is None and irregular_axes is None: + regular_axes = () + irregular_axes = tuple(range(self.ndim)) + else: + if set(regular_axes) | set(irregular_axes) != set(range(self.ndim)): + ve = "`regular_axes` and `irregular_axes` do not span the full axes" + raise ValueError(ve) + if set(regular_axes) & set(irregular_axes) != set(): + ve = "`regular_axes` and `irregular_axes` must be exclusive" + raise ValueError(ve) + self._regular_axes = tuple(regular_axes) + self._irregular_axes = tuple(irregular_axes) + if len(self.regular_axes) + len(self.irregular_axes) != self.ndim: + ve = ( + f"length of regular_axes and irregular_axes" + f" ({len(self.regular_axes)} + {len(self.irregular_axes)} respectively)" + f" incompatible with overall dimension {self.ndim}" + ) + raise ValueError(ve) + + self._descr = { + "depth": self.depth, + "shape0": self.shape0, + "_coarse_size": self.coarse_size, + "_fine_size": self.fine_size, + "_fine_strategy": self.fine_strategy, + } + if distances0 is not None: + self._descr["distances0"] = tuple(distances0) + else: + self._descr["rg2cart"] = repr(rg2cart) + self._descr["cart2rg"] = repr(cart2rg) + self._descr["regular_axes"] = self.regular_axes + + @property + def shape(self): + """Shape at the final refinement level""" + return self._shape + + @property + def shape0(self): + """Shape at the zeroth refinement level""" + return self._shape0 + + @property + def size(self): + return self._size + + @property + def ndim(self): + return self._ndim + + @property + def depth(self): + return self._depth + + @property + def coarse_size(self): + return self._coarse_size + + @property + def fine_size(self): + return self._fine_size + + @property + def fine_strategy(self): + return self._fine_strategy + + @property + def regular_axes(self): + return self._regular_axes + + @property + def irregular_axes(self): + return self._irregular_axes + + def rg2cart(self, positions): + """Translates positions from the regular Euclidean coordinate system to + the (in general) irregular Cartesian coordinate system. + + Parameters + ---------- + positions : + Positions on a regular Euclidean coordinate system. + + Returns + ------- + positions : + Positions on an (in general) irregular Cartesian coordinate system. + + Note + ---- + This method is independent of the refinement level! + """ + return self._rg2cart(positions) + + def cart2rg(self, positions): + """Translates positions from the (in general) irregular Cartesian + coordinate system to the regular Euclidean coordinate system. + + Parameters + ---------- + positions : + Positions on an (in general) irregular Cartesian coordinate system. + + Returns + ------- + positions : + Positions on a regular Euclidean coordinate system. + + Note + ---- + This method is independent of the refinement level! + """ + return self._cart2rg(positions) + + def rgoffset(self, lvl: int) -> Tuple[float]: + """Calculate the offset on the regular Euclidean grid due to shrinking + of the grid with increasing refinement level. + + Parameters + ---------- + lvl : + Level of the refinement. + + Returns + ------- + offset : + The offset on the regular Euclidean grid along each axes. + + Note + ---- + Indices are assumed to denote the center of the pixels, i.e. the pixel + with index `0` is assumed to be at `(0., ) * ndim`. + """ + csz = self.coarse_size # abbreviations for readability + fsz = self.fine_size + + leftmost_center = 0. 
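+        # Illustrative values: for `csz=3`, `fsz=2` and the "jump" strategy,
+        # each level `i` contributes `lm0 / fsz**i` with `lm0 = 0.75`, i.e.
+        # the offset is 0.75 at `lvl=1`, 1.125 at `lvl=2`, and approaches
+        # `lm0 / (1 - 1 / fsz) = 1.5` for large `lvl`.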
+ # Assume the indices denote the center of the pixels, i.e. the pixel + # with index 0 is at (0., ) * ndim + if self.fine_strategy == "jump": + # for i in range(lvl): + # leftmost_center += ((csz - 1) / 2 - 0.5 + 0.5 / fsz) / fsz**i + lm0 = (csz - 1) / 2 - 0.5 + 0.5 / fsz + geo = (1. - fsz** + -lvl) / (1. - 1. / fsz) # sum(fsz**-i for i in range(lvl)) + leftmost_center = lm0 * geo + elif self.fine_strategy == "extend": + # for i in range(lvl): + # leftmost_center += ((csz - 1) / 2 - 0.25 * (fsz - 1)) / 2**i + lm0 = ((csz - 1) / 2 - 0.25 * (fsz - 1)) + geo = (1. - 2.**-lvl) * 2. # sum(fsz**-i for i in range(lvl)) + leftmost_center = lm0 * geo + else: + raise AssertionError() + return tuple((leftmost_center, ) * self.ndim) + + def ind2rg(self, indices: Iterable[Union[float, int]], + lvl: int) -> Tuple[float]: + """Converts pixel indices to a continuous regular Euclidean grid + coordinates. + + Parameters + ---------- + indices : + Indices of shape `(n_dim, n_indices)` into the NDArray at + refinement level `lvl` which to convert to points in our regular + Euclidean grid. + lvl : + Level of the refinement. + + Returns + ------- + rg : + Regular Euclidean grid coordinates of shape `(n_dim, n_indices)`. + """ + offset = self.rgoffset(lvl) + + if self.fine_strategy == "jump": + dvol = 1 / self.fine_size**lvl + elif self.fine_strategy == "extend": + dvol = 1 / 2**lvl + else: + raise AssertionError() + return tuple(off + idx * dvol for off, idx in zip(offset, indices)) + + def rg2ind( + self, + positions: Iterable[Union[float, int]], + lvl: int, + discretize: bool = True + ) -> Union[Tuple[float], Tuple[int]]: + """Converts continuous regular grid positions to pixel indices. + + Parameters + ---------- + positions : + Positions on the regular Euclidean coordinate system of shape + `(n_dim, n_indices)` at refinement level `lvl` which to convert to + indices in a NDArray at the refinement level `lvl`. + lvl : + Level of the refinement. + discretize : + Whether to round indices to the next closest integer. + + Returns + ------- + indices : + Indices into the NDArray at refinement level `lvl`. + """ + offset = self.rgoffset(lvl) + + if self.fine_strategy == "jump": + dvol = 1 / self.fine_size**lvl + elif self.fine_strategy == "extend": + dvol = 1 / 2**lvl + else: + raise AssertionError() + indices = tuple(pos / dvol - off for off, pos in zip(offset, positions)) + if discretize: + indices = tuple(jnp.rint(idx).astype(jnp.int32) for idx in indices) + return indices + + def ind2cart(self, indices: Iterable[Union[float, int]], lvl: int): + """Computes the Cartesian coordinates of a pixel given the indices of + it. + + Parameters + ---------- + indices : + Indices of shape `(n_dim, n_indices)` into the NDArray at + refinement level `lvl` which to convert to locations in our (in + general) irregular coordinate system of the modeled points. + lvl : + Level of the refinement. + + Returns + ------- + positions : + Positions in the (in general) irregular coordinate system of the + modeled points of shape `(n_dim, n_indices)`. + """ + return self.rg2cart(self.ind2rg(indices, lvl)) + + def cart2ind(self, positions, lvl, discretize=True): + """Computes the indices of a pixel given the Cartesian coordinates of + it. + + Parameters + ---------- + positions : + Positions on the Cartesian (in general) irregular coordinate system + of the modeled points of shape `(n_dim, n_indices)` at refinement + level `lvl` which to convert to indices in a NDArray at the + refinement level `lvl`. + lvl : + Level of the refinement. 
+ discretize : + Whether to round indices to the next closest integer. + + Returns + ------- + indices : + Indices into the NDArray at refinement level `lvl`. + """ + return self.rg2ind(self.cart2rg(positions), lvl, discretize=discretize) + + def shape_at(self, lvl): + """Retrieves the shape at a given refinement level `lvl`.""" + return self._shape_at(lvl) + + def level_of(self, shape: Tuple[int]): + """Finds the refinement level at which the number of grid points + equate. + """ + if not isinstance(shape, tuple): + raise TypeError(f"invalid type of `shape`; got {type(shape)}") + + for lvl in range(self.depth + 1): + if shape == self.shape_at(lvl): + return lvl + else: + raise ValueError(f"invalid shape {shape!r}") + + def __repr__(self): + return f"{self.__class__.__name__}(**{self._descr})" + + def __eq__(self, other): + return repr(self) == repr(other) + + +RefinementMatrices = namedtuple( + "RefinementMatrices", ("filter", "propagator_sqrt", "cov_sqrt0") +) + + +class RefinementField(): + def __init__( + self, + *args, + kernel: Optional[Callable] = None, + dtype=None, + skip0: bool = False, + **kwargs + ): + """Initialize an Iterative Charted Refinement (ICR) field. + + There are multiple ways to initialize a charted refinement field. The + recommended way is to first instantiate a `CoordinateChart` and pass it + as first argument to this method. Alternatively, you may pass any and + all arguments of `CoordinateChart` also to this method and it will + instantiate the `CoordinateChart` for you and use it in the same way as + if directly specified. + + Parameters + ---------- + chart : CoordinateChart + The `CoordinateChart` with which to refine. + kernel : + Covariance kernel of the refinement field. + dtype : + Data-type of the excitations which to add during refining. + skip0 : + Whether to skip the first refinement level. This is useful to e.g. + stack multiple refinement fields on top of each other. + **kwargs : + Alternatively to `chart` any parameters accepted by + `CoordinateChart`. + """ + self._kernel = kernel + self._dtype = dtype + self._skip0 = skip0 + + if len(args) > 0 and isinstance(args[0], CoordinateChart): + if kwargs: + raise TypeError(f"expected no keyword arguments, got {kwargs}") + + if len(args) == 1: + self._chart, = args + elif len(args) == 2 and callable(args[1]) and kernel is None: + self._chart, self._kernel = args + elif len(args) == 3 and callable( + args[1] + ) and kernel is None and dtype is None: + self._chart, self._kernel, self._dtype = args + elif len(args) == 4 and callable( + args[1] + ) and kernel is None and dtype is None and skip0 == False: + self._chart, self._kernel, self._dtype, self._skip0 = args + else: + te = "got unexpected arguments in addition to CoordinateChart" + raise TypeError(te) + else: + self._chart = CoordinateChart(*args, **kwargs) + + @property + def kernel(self): + """Yields the kernel specified during initialization or throw a + `TypeError`. 
+ """ + if self._kernel is None: + te = ( + "either specify a fixed kernel during initialization of the" + f" {self.__class__.__name__} class or provide one here" + ) + raise TypeError(te) + return self._kernel + + @property + def dtype(self): + """Yields the data-type of the excitations.""" + return jnp.float64 if self._dtype is None else self._dtype + + @property + def skip0(self): + """Whether to skip the zeroth refinement""" + return self._skip0 + + @property + def chart(self): + """Associated `CoordinateChart` with which to iterative refine.""" + return self._chart + + def matrices( + self, + kernel: Optional[Callable] = None, + depth: Optional[int] = None, + skip0: Optional[bool] = None, + **kwargs + ) -> RefinementMatrices: + """Computes the refinement matrices namely the optimal linear filter + and the square root of the information propagator (a.k.a. the square + root of the fine covariance matrix for the excitations) for all + refinement levels and all pixel indices in the coordinate chart. + + Parameters + ---------- + kernel : + Covariance kernel of the refinement field if not specified during + initialization. + depth : + Maximum refinement depth if different to the one of the `CoordinateChart`. + skip0 : + Whether to skip the first refinement level. + """ + kernel = self.kernel if kernel is None else kernel + depth = self.chart.depth if depth is None else depth + skip0 = self.skip0 if skip0 is None else skip0 + + return _coordinate_refinement_matrices( + self.chart, kernel=kernel, depth=depth, skip0=skip0, **kwargs + ) + + def matrices_at( + self, + level: int, + pixel_index: Optional[Iterable[int]] = None, + kernel: Optional[Callable] = None, + **kwargs + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Computes the refinement matrices namely the optimal linear filter + and the square root of the information propagator (a.k.a. the square + root of the fine covariance matrix for the excitations) at the + specified level and pixel index. + + Parameters + ---------- + level : + Refinement level. + pixel_index : + Index of the NDArray at the refinement level `level` which to + refine, i.e. use as center coarse pixel. + kernel : + Covariance kernel of the refinement field if not specified during + initialization. + """ + kernel = self.kernel if kernel is None else kernel + + return _coordinate_pixel_refinement_matrices( + self.chart, + level=level, + pixel_index=pixel_index, + kernel=kernel, + **kwargs + ) + + @property + def shapewithdtype(self): + """Yields the `ShapeWithDtype` of the primals.""" + return get_refinement_shapewithdtype( + shape0=self.chart.shape0, + depth=self.chart.depth, + dtype=self.dtype, + skip0=self.skip0, + _coarse_size=self.chart.coarse_size, + _fine_size=self.chart.fine_size, + _fine_strategy=self.chart.fine_strategy, + ) + + @staticmethod + def apply( + xi, + chart, + kernel: Union[Callable, RefinementMatrices], + *, + skip0: bool = False, + depth: Optional[int] = None, + coerce_fine_kernel: bool = True, + _refine: Optional[Callable] = None, + precision=None, + ): + """Static method to apply a refinement field given some excitations, a + chart and a kernel. + + Parameters + ---------- + xi : + Latent parameters which to use for refining. + chart : + Chart with which to refine. + kernel : + Covariance kernel with which to build the refinement matrices. + skip0 : + Whether to skip the first refinement level. + depth : + Refinement depth if different to the depth of the coordinate chart. 
+ coerce_fine_kernel : + Whether to coerce the refinement matrices at scales at which the + kernel matrix becomes singular or numerically highly unstable. + precision : + See JAX's precision. + """ + depth = chart.depth if depth is None else depth + if depth != len(xi) - 1: + ve = ( + f"incompatible refinement depths of `xi` ({len(xi) - 1})" + f" and `depth` (of chart) {depth}" + ) + raise ValueError(ve) + + if isinstance(kernel, RefinementMatrices): + refinement = kernel + else: + refinement = _coordinate_refinement_matrices( + chart, + kernel=kernel, + depth=depth, + skip0=skip0, + coerce_fine_kernel=coerce_fine_kernel + ) + refine_w_chart = partial( + refine if _refine is None else _refine, + _coarse_size=chart.coarse_size, + _fine_size=chart.fine_size, + _fine_strategy=chart.fine_strategy, + precision=precision + ) + + if not skip0: + fine = (refinement.cov_sqrt0 @ xi[0].ravel()).reshape(xi[0].shape) + else: + if refinement.cov_sqrt0 is not None: + raise AssertionError() + fine = xi[0] + for x, olf, k in zip( + xi[1:], refinement.filter, refinement.propagator_sqrt + ): + fine = refine_w_chart(fine, x, olf, k) + return fine + + def __call__(self, xi, kernel=None, *, skip0=None, **kwargs): + """See `RefinementField.apply`.""" + kernel = self.kernel if kernel is None else kernel + skip0 = self.skip0 if skip0 is None else skip0 + return self.apply(xi, self.chart, kernel=kernel, skip0=skip0, **kwargs) + + def __repr__(self): + descr = f"{self.__class__.__name__}({self.chart!r}" + descr += f", kernel={self._kernel!r}" if self._kernel is not None else "" + descr += f", dtype={self._dtype!r}" if self._dtype is not None else "" + descr += f", skip0={self.skip0!r}" if self.skip0 is not False else "" + descr += ")" + return descr + + def __eq__(self, other): + return repr(self) == repr(other) + + +def _coordinate_pixel_refinement_matrices( + chart: CoordinateChart, + level: int, + pixel_index: Optional[Iterable[int]] = None, + kernel: Optional[Callable] = None, + *, + coerce_fine_kernel: bool = True, + _cov_from_loc: Optional[Callable] = None, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + cov_from_loc = _get_cov_from_loc(kernel, _cov_from_loc) + csz = int(chart.coarse_size) # coarse size + if csz % 2 != 1: + raise ValueError("only odd numbers allowed for `_coarse_size`") + fsz = int(chart.fine_size) # fine size + if fsz % 2 != 0: + raise ValueError("only even numbers allowed for `_fine_size`") + ndim = chart.ndim + if pixel_index is None: + pixel_index = (0, ) * ndim + pixel_index = jnp.asarray(pixel_index) + if pixel_index.size != ndim: + ve = f"`pixel_index` has {pixel_index.size} dimensions but `chart` has {ndim}" + raise ValueError(ve) + + csz_half = int((csz - 1) / 2) + gc = jnp.arange(-csz_half, csz_half + 1, dtype=float) + gc = jnp.ones((ndim, 1)) * gc + gc = jnp.stack(jnp.meshgrid(*gc, indexing="ij"), axis=-1) + if chart.fine_strategy == "jump": + gf = jnp.arange(fsz, dtype=float) / fsz - 0.5 + 0.5 / fsz + elif chart.fine_strategy == "extend": + gf = jnp.arange(fsz, dtype=float) / 2 - 0.25 * (fsz - 1) + else: + raise ValueError(f"invalid `_fine_strategy`; got {chart.fine_strategy}") + gf = jnp.ones((ndim, 1)) * gf + gf = jnp.stack(jnp.meshgrid(*gf, indexing="ij"), axis=-1) + # On the GPU a single `cov_from_loc` call is about twice as fast as three + # separate calls for coarse-coarse, fine-fine and coarse-fine. 
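+    # What follows is the conditional Gaussian of the fine pixels given the
+    # coarse pixels: `olf` is the conditional-mean filter
+    # `cov_fc @ cov_cc^-1` and `fine_kernel` is the Schur complement
+    # `cov_ff - cov_fc @ cov_cc^-1 @ cov_fc.T`, i.e. the conditional
+    # covariance of the fine pixels.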
+ coord = jnp.concatenate( + (gc.reshape(-1, ndim), gf.reshape(-1, ndim)), axis=0 + ) + coord = chart.ind2cart((coord + pixel_index.reshape((1, ndim))).T, level) + coord = jnp.stack(coord, axis=-1) + cov = cov_from_loc(coord, coord) + cov_ff = cov[-fsz**ndim:, -fsz**ndim:] + cov_fc = cov[-fsz**ndim:, :-fsz**ndim] + cov_cc = cov[:-fsz**ndim, :-fsz**ndim] + cov_cc_inv = jnp.linalg.inv(cov_cc) + + olf = cov_fc @ cov_cc_inv + # Also see Schur-Complement + fine_kernel = cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T + if coerce_fine_kernel: + # Implicitly assume a white power spectrum beyond the numerics limit. + # Use the diagonal as estimate for the magnitude of the variance. + fine_kernel_fallback = jnp.diag(jnp.abs(jnp.diag(fine_kernel))) + # Never produce NaNs (https://github.com/google/jax/issues/1052) + # This is expensive but necessary (worse but cheaper: + # `jnp.all(jnp.diag(fine_kernel) > 0.)`) + is_pos_def = jnp.all(jnp.linalg.eigvalsh(fine_kernel) > 0) + fine_kernel = jnp.where(is_pos_def, fine_kernel, fine_kernel_fallback) + # NOTE, subsequently use the Cholesky decomposition, even though + # already having computed the eigenvalues, as to get consistent results + # across platforms + fine_kernel_sqrt = jnp.linalg.cholesky(fine_kernel) + + return olf, fine_kernel_sqrt + + +def _coordinate_refinement_matrices( + chart: CoordinateChart, + kernel: Callable, + *, + depth: Optional[int] = None, + skip0=False, + coerce_fine_kernel: bool = True, + _cov_from_loc=None +) -> RefinementMatrices: + cov_from_loc = _get_cov_from_loc(kernel, _cov_from_loc) + depth = chart.depth if depth is None else depth + + if not skip0: + rg0 = jnp.mgrid[tuple(slice(s) for s in chart.shape0)] + c0 = jnp.stack(chart.ind2cart(rg0, 0), axis=-1).reshape(-1, chart.ndim) + cov_sqrt0 = jnp.linalg.cholesky(cov_from_loc(c0, c0)) + else: + cov_sqrt0 = None + + opt_lin_filter, kernel_sqrt = [], [] + olf_at = vmap( + partial( + _coordinate_pixel_refinement_matrices, + chart, + coerce_fine_kernel=coerce_fine_kernel, + _cov_from_loc=cov_from_loc, + ), + in_axes=(None, 0), + out_axes=(0, 0) + ) + + for lvl in range(depth): + shape_lvl = chart.shape_at(lvl) + pixel_indices = [] + for ax in range(chart.ndim): + pad = (chart.coarse_size - 1) / 2 + if int(pad) != pad: + raise ValueError("`coarse_size` must be odd") + pad = int(pad) + if chart.fine_strategy == "jump": + stride = 1 + elif chart.fine_strategy == "extend": + stride = chart.fine_size / 2 + if int(stride) != stride: + raise ValueError("`fine_size` must be even") + stride = int(stride) + else: + raise AssertionError() + if ax in chart.irregular_axes: + pixel_indices.append( + jnp.arange(pad, shape_lvl[ax] - pad, stride) + ) + else: + pixel_indices.append(jnp.array([pad])) + pixel_indices = jnp.stack( + jnp.meshgrid(*pixel_indices, indexing="ij"), axis=-1 + ) + shape_filtered_lvl = pixel_indices.shape[:-1] + pixel_indices = pixel_indices.reshape(-1, chart.ndim) + + olf, ks = olf_at(lvl, pixel_indices) + shape_bc_lvl = tuple( + shape_filtered_lvl[i] if i in chart.irregular_axes else 1 + for i in range(chart.ndim) + ) + opt_lin_filter.append(olf.reshape(shape_bc_lvl + olf.shape[-2:])) + kernel_sqrt.append(ks.reshape(shape_bc_lvl + ks.shape[-2:])) + + return RefinementMatrices(opt_lin_filter, kernel_sqrt, cov_sqrt0) diff --git a/src/re/refine_util.py b/src/re/refine_util.py new file mode 100644 index 0000000000000000000000000000000000000000..443baeb00f6566a1751a70415b5be233a19c0c01 --- /dev/null +++ b/src/re/refine_util.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 + +# 
SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +from math import ceil +import sys +from typing import Callable, Iterable, Literal, Optional, Tuple, Union +from warnings import warn + +import jax +from jax import numpy as jnp +import numpy as np +from scipy.spatial import distance_matrix + +from .forest_util import zeros_like + + +def get_refinement_shapewithdtype( + shape0: Union[int, tuple], + depth: int, + dtype=None, + *, + _coarse_size: int = 3, + _fine_size: int = 2, + _fine_strategy: Literal["jump", "extend"] = "jump", + skip0: bool = False, +): + from .forest_util import ShapeWithDtype + + if depth < 0: + raise ValueError(f"invalid `depth`; got {depth!r}") + csz = int(_coarse_size) # coarse size + fsz = int(_fine_size) # fine size + + swd = partial(ShapeWithDtype, dtype=dtype) + + shape0 = (shape0, ) if isinstance(shape0, int) else shape0 + ndim = len(shape0) + exc_shp = [swd(shape0)] if not skip0 else [None] + if depth > 0: + if _fine_strategy == "jump": + exc_lvl = tuple(el - (csz - 1) for el in shape0) + (fsz**ndim, ) + elif _fine_strategy == "extend": + exc_lvl = tuple( + ceil((el - (csz - 1)) / (fsz // 2)) for el in shape0 + ) + (fsz**ndim, ) + else: + raise ValueError(f"invalid `_fine_strategy`; got {_fine_strategy}") + exc_shp += [swd(exc_lvl)] + for lvl in range(1, depth): + if _fine_strategy == "jump": + exc_lvl = tuple( + fsz * el - (csz - 1) for el in exc_shp[-1].shape[:-1] + ) + (fsz**ndim, ) + elif _fine_strategy == "extend": + exc_lvl = tuple( + ceil((fsz * el - (csz - 1)) / (fsz // 2)) + for el in exc_shp[-1].shape[:-1] + ) + (fsz**ndim, ) + else: + raise AssertionError() + if any(el <= 0 for el in exc_lvl): + ve = ( + f"`shape0` ({shape0}) with `depth` ({depth}) yield an" + f" invalid shape ({exc_lvl}) at level {lvl}" + ) + raise ValueError(ve) + exc_shp += [swd(exc_lvl)] + + return exc_shp + + +def coarse2fine_shape( + shape0: Union[int, Iterable[int]], + depth: int, + *, + _coarse_size: int = 3, + _fine_size: int = 2, + _fine_strategy: Literal["jump", "extend"] = "jump", +): + """Translates a coarse shape to its corresponding fine shape.""" + shape0 = (shape0, ) if isinstance(shape0, int) else shape0 + csz = int(_coarse_size) # coarse size + fsz = int(_fine_size) # fine size + if _fine_size % 2 != 0: + raise ValueError("only even numbers allowed for `_fine_size`") + + shape = [] + for shp in shape0: + sz_at = shp + for lvl in range(depth): + if _fine_strategy == "jump": + sz_at = fsz * (sz_at - (csz - 1)) + elif _fine_strategy == "extend": + sz_at = fsz * ceil((sz_at - (csz - 1)) / (fsz // 2)) + else: + ve = f"invalid `_fine_strategy`; got {_fine_strategy}" + raise ValueError(ve) + if sz_at <= 0: + ve = ( + f"`shape0` ({shape0}) with `depth` ({depth}) yield an" + f" invalid shape ({sz_at}) at level {lvl}" + ) + raise ValueError(ve) + shape.append(int(sz_at)) + return tuple(shape) + + +def fine2coarse_shape( + shape: Union[int, Iterable[int]], + depth: int, + *, + _coarse_size: int = 3, + _fine_size: int = 2, + _fine_strategy: Literal["jump", "extend"] = "jump", + ceil_sizes: bool = False, +): + """Translates a fine shape to its corresponding coarse shape.""" + shape = (shape, ) if isinstance(shape, int) else shape + csz = int(_coarse_size) # coarse size + fsz = int(_fine_size) # fine size + if _fine_size % 2 != 0: + raise ValueError("only even numbers allowed for `_fine_size`") + + shape0 = [] + for shp in shape: + sz_at = shp + for lvl in range(depth, 0, -1): + if _fine_strategy == "jump": + # solve for n: `fsz * (n - (csz - 
1)) = sz_at`
+                sz_at = sz_at / fsz + (csz - 1)
+            elif _fine_strategy == "extend":
+                # solve `fsz * ceil((n - (csz - 1)) / (fsz // 2)) = sz_at` for n
+                # NOTE, not unique because of `ceil`; use lower limit
+                sz_at_max = (sz_at / fsz) * (fsz // 2) + (csz - 1)
+                sz_at_min = ceil(sz_at_max - (fsz // 2 - 1))
+                for sz_at_cand in range(sz_at_min, ceil(sz_at_max) + 1):
+                    try:
+                        shp_cand = coarse2fine_shape(
+                            (sz_at_cand, ),
+                            depth=depth - lvl + 1,
+                            _coarse_size=csz,
+                            _fine_size=fsz,
+                            _fine_strategy=_fine_strategy
+                        )[0]
+                    except ValueError as e:
+                        if "invalid shape" not in "".join(e.args):
+                            ve = "unexpected behavior of `coarse2fine_shape`"
+                            raise ValueError(ve) from e
+                        shp_cand = -1
+                    if shp_cand >= shp:
+                        sz_at = sz_at_cand
+                        break
+                else:
+                    ve = f"interval search within [{sz_at_min}, {ceil(sz_at_max)}] failed"
+                    raise ValueError(ve)
+            else:
+                ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+                raise ValueError(ve)
+
+            sz_at = ceil(sz_at) if ceil_sizes else sz_at
+            if sz_at != int(sz_at):
+                raise ValueError(f"invalid shape at level {lvl}")
+        shape0.append(int(sz_at))
+    return tuple(shape0)
+
+
+def coarse2fine_distances(
+    distances0: Union[float, Iterable[float]],
+    depth: int,
+    *,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+):
+    """Translates coarse distances to the corresponding fine distances."""
+    fsz = int(_fine_size)  # fine size
+    if _fine_strategy == "jump":
+        fpx_in_cpx = fsz**depth
+    elif _fine_strategy == "extend":
+        fpx_in_cpx = 2**depth
+    else:
+        ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+        raise ValueError(ve)
+
+    return jnp.atleast_1d(distances0) / fpx_in_cpx
+
+
+def fine2coarse_distances(
+    distances: Union[float, Iterable[float]],
+    depth: int,
+    *,
+    _fine_size: int = 2,
+    _fine_strategy: Literal["jump", "extend"] = "jump",
+):
+    """Translates fine distances to the corresponding coarse distances."""
+    fsz = int(_fine_size)  # fine size
+    if _fine_strategy == "jump":
+        fpx_in_cpx = fsz**depth
+    elif _fine_strategy == "extend":
+        fpx_in_cpx = 2**depth
+    else:
+        ve = f"invalid `_fine_strategy`; got {_fine_strategy}"
+        raise ValueError(ve)
+
+    return jnp.atleast_1d(distances) * fpx_in_cpx
+
+
+def _clipping_posdef_logdet(mat, msg_prefix=""):
+    sign, logdet = jnp.linalg.slogdet(mat)
+    if sign <= 0:
+        ve = "not positive definite; clipping eigenvalues"
+        warn(msg_prefix + ve)
+        eps = jnp.finfo(mat.dtype.type).eps
+        evs = jnp.linalg.eigvalsh(mat)
+        logdet = jnp.sum(jnp.log(jnp.clip(evs, a_min=eps * evs.max())))
+    return logdet
+
+
+def gauss_kl(cov_desired, cov_approx, *, m_desired=None, m_approx=None):
+    cov_t_dl = _clipping_posdef_logdet(cov_desired, msg_prefix="`cov_desired` ")
+    cov_a_dl = _clipping_posdef_logdet(cov_approx, msg_prefix="`cov_approx` ")
+    cov_a_inv = jnp.linalg.inv(cov_approx)
+
+    kl = -cov_desired.shape[0]  # number of dimensions
+    kl += cov_a_dl - cov_t_dl + jnp.trace(cov_a_inv @ cov_desired)
+    if m_approx is not None and m_desired is not None:
+        m_diff = m_approx - m_desired
+        kl += m_diff @ cov_a_inv @ m_diff
+    elif not (m_approx is None and m_desired is None):
+        ve = "either both or neither of `m_approx` and `m_desired` must be `None`"
+        raise ValueError(ve)
+    return 0.5 * kl
+
+
+def refinement_covariance(chart, kernel, jit=True):
+    """Computes the implied covariance as modeled by the refinement scheme."""
+    from .refine_chart import RefinementField
+
+    cf = RefinementField(chart, kernel=kernel)
+    try:
+        cf_T = jax.linear_transpose(cf, cf.shapewithdtype)
+        cov_implicit = lambda x: cf(*cf_T(x))
+        cov_implicit =
jax.jit(cov_implicit) if jit else cov_implicit + _ = cov_implicit(jnp.zeros(chart.shape)) # Test transpose + except (NotImplementedError, AssertionError): + # Workaround JAX not yet implementing the transpose of the scanned + # refinement + _, cf_T = jax.vjp(cf, zeros_like(cf.shapewithdtype)) + cov_implicit = lambda x: cf(*cf_T(x)) + cov_implicit = jax.jit(cov_implicit) if jit else cov_implicit + + probe = jnp.zeros(chart.shape) + indices = np.indices(chart.shape).reshape(chart.ndim, -1) + cov_empirical = jax.lax.map( + lambda idx: cov_implicit(probe.at[tuple(idx)].set(1.)).ravel(), + indices.T + ).T # vmap over `indices` w/ `in_axes=1, out_axes=-1` + + return cov_empirical + + +def true_covariance(chart, kernel, depth=None): + """Computes the true covariance at the final grid.""" + depth = chart.depth if depth is None else depth + + c0_slc = tuple(slice(sz) for sz in chart.shape_at(depth)) + pos = jnp.stack(chart.ind2cart(jnp.mgrid[c0_slc], depth), + axis=-1).reshape(-1, chart.ndim) + dist_mat = distance_matrix(pos, pos) + return kernel(dist_mat) + + +def refinement_approximation_error( + chart, + kernel: Callable, + cutout: Optional[Union[slice, int, Tuple[slice], Tuple[int]]] = None, +): + """Computes the Kullback-Leibler (KL) divergence of the true covariance versus the + approximative one for a given kernel and shape of the fine grid. + + If the desired shape can not be matched, the next larger one is used and + the field is subsequently cropped to the desired shape. + """ + + suggested_min_shape = 2 * 4**chart.depth + if any(s <= suggested_min_shape for s in chart.shape): + msg = ( + f"shape {chart.shape} potentially too small" + f" (desired {(suggested_min_shape, ) * chart.ndim} (=`2*4^depth`))" + ) + warn(msg) + + cov_empirical = refinement_covariance(chart, kernel) + cov_truth = true_covariance(chart, kernel) + + if cutout is None and all(s > suggested_min_shape for s in chart.shape): + cutout = (suggested_min_shape, ) * chart.ndim + print( + f"cropping field (w/ shape {chart.shape}) to {cutout}", + file=sys.stderr + ) + if cutout is not None: + if isinstance(cutout, slice): + cutout = (cutout, ) * chart.ndim + elif isinstance(cutout, int): + cutout = (slice(cutout), ) * chart.ndim + elif isinstance(cutout, tuple): + if all(isinstance(el, slice) for el in cutout): + pass + elif all(isinstance(el, int) for el in cutout): + cutout = tuple(slice(el) for el in cutout) + else: + raise TypeError("elements of `cutout` of invalid type") + else: + raise TypeError("`cutout` of invalid type") + + cov_empirical = cov_empirical.reshape(chart.shape * 2)[cutout * 2] + cov_truth = cov_truth.reshape(chart.shape * 2)[cutout * 2] + sz = np.prod(cov_empirical.shape[:chart.ndim]) + if np.prod(cov_truth.shape[:chart.ndim]) != sz or not sz.dtype == int: + raise AssertionError() + cov_empirical = cov_empirical.reshape(sz, sz) + cov_truth = cov_truth.reshape(sz, sz) + + aux = { + "cov_empirical": cov_empirical, + "cov_truth": cov_truth, + } + return gauss_kl(cov_truth, cov_empirical), aux diff --git a/src/re/stats_distributions.py b/src/re/stats_distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..de6980b8196909ec2ca7f342698f3824afdfb2b8 --- /dev/null +++ b/src/re/stats_distributions.py @@ -0,0 +1,254 @@ +from typing import Callable, Optional + +from jax import numpy as jnp + + +def laplace_prior(alpha) -> Callable: + """ + Takes random normal samples and outputs samples distributed according to + + .. 
math:: + P(x|a) = exp(-|x|/a)/a/2 + + """ + from jax.scipy.stats import norm + + def standard_to_laplace(xi): + res = (xi < 0) * (norm.logcdf(xi) + jnp.log(2)) + res -= (xi > 0) * (norm.logcdf(-xi) + jnp.log(2)) + return res * alpha + + return standard_to_laplace + + +def normal_prior(mean, std) -> Callable: + """Match standard normally distributed random variables to non-standard + variables. + """ + def standard_to_normal(xi): + return mean + std * xi + + return standard_to_normal + + +def lognormal_moments(mean, std): + """Compute the cumulants a log-normal process would need to comply with the + provided mean and standard-deviation `std` + """ + if jnp.any(mean <= 0.): + raise ValueError(f"`mean` must be greater zero; got {mean!r}") + if jnp.any(std <= 0.): + raise ValueError(f"`std` must be greater zero; got {std!r}") + + logstd = jnp.sqrt(jnp.log1p((std / mean)**2)) + logmean = jnp.log(mean) - 0.5 * logstd**2 + return logmean, logstd + + +def lognormal_prior(mean, std) -> Callable: + """Moment-match standard normally distributed random variables to log-space + + Takes random normal samples and outputs samples distributed according to + + .. math:: + P(xi|mu,sigma) \\propto exp(mu + sigma * xi) + + such that the mean and standard deviation of the distribution matches the + specified values. + """ + standard_to_normal = normal_prior(*lognormal_moments(mean, std)) + + def standard_to_lognormal(xi): + return jnp.exp(standard_to_normal(xi)) + + return standard_to_lognormal + + +def lognormal_invprior(mean, std) -> Callable: + """Get the inverse transform to `lognormal_prior`.""" + ln_m, ln_std = lognormal_moments(mean, std) + + def lognormal_to_standard(y): + return (jnp.log(y) - ln_m) / ln_std + + return lognormal_to_standard + + +def uniform_prior(a_min=0., a_max=1.) -> Callable: + """Transform a standard normal into a uniform distribution. + + Parameters + ---------- + a_min : float + Minimum value. + a_max : float + Maximum value. + """ + from jax.scipy.stats import norm + + if a_min == 0. and a_max == 1.: + return norm.cdf + + scale = a_max - a_min + + def standard_to_uniform(xi): + return a_min + scale * norm.cdf(xi) + + return standard_to_uniform + + +def interpolator( + func: Callable, + xmin: float, + xmax: float, + *, + step: Optional[float] = None, + num: Optional[int] = None, + table_func: Optional[Callable] = None, + inv_table_func: Optional[Callable] = None, + return_inverse: Optional[bool] = False +): # Adapted from NIFTy + """ + Evaluate a function point-wise by interpolation. Can be supplied with a + table_func to increase the interpolation accuracy, Best results are + achieved when `lambda x: table_func(func(x))` is roughly linear. + + Parameters + ---------- + func : function + Function to interpolate. + xmin : float + The smallest value for which `func` will be evaluated. + xmax : float + The largest value for which `func` will be evaluated. + step : float + Distance between sampling points for linear interpolation. Either of + `step` or `num` must be specified. + num : int + The number of interpolation points. Either of `step` of `num` must be + specified. + table_func : function + Non-linear function applied to the tabulated function in order to + transform the table to a more linear space. + inv_table_func : function + Inverse of `table_func`. + return_inverse : bool + Whether to also return the interpolation of the inverse of `func`. Only + sensible if `func` is invertible. 
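+
+    A minimal usage sketch (illustrative only):
+
+        exp_interp, exp_inv = interpolator(
+            jnp.exp, -3., 3., num=301, return_inverse=True
+        )
+        exp_interp(jnp.array(0.))  # ~ 1.
+        exp_inv(jnp.array(1.))  # ~ 0., the inverse of exp being log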
+ """ + # from scipy.interpolate import CubicSpline + + if step is not None and num is not None: + ve = "either but not both of `step` and `num` must be specified" + raise ValueError(ve) + if step is not None: + xs = jnp.arange(xmin, xmax + step, step) + elif num is not None: + xs = jnp.linspace(xmin, xmax, num) + else: + ve = "either of `step` or `num` must be specified" + raise ValueError(ve) + + ys = func(xs) + if table_func is not None: + if inv_table_func is None: + raise ValueError("no `inv_table_func` specified") + ys = table_func(ys) + + # interpolator = CubicSpline(xs, ys) + # deriv = interpolator.derivative() + + def interp(x): + # res = interpolator(x) + res = jnp.interp(x, xs, ys) + if inv_table_func is not None: + res = inv_table_func(res) + return res + + if return_inverse: + + def inverse_interp(y): + if table_func is not None: + y = table_func(y) + return jnp.interp(y, ys, xs) + + return interp, inverse_interp + + return interp + + +def invgamma_prior(a, scale, loc=0., step=1e-2) -> Callable: + """Transform a standard normal into an inverse gamma distribution. + + The pdf of the inverse gamma distribution is defined as follows using + :math:`q` to denote the scale: + + .. math:: + + P(x|q, a) = \\frac{q^a}{\\Gamma(a)}x^{-a -1} + \\exp \\left(-\\frac{q}{x}\\right) + + That means that for large x the pdf falls off like :math:`x^{(-a -1)}`. + The mean of the pdf is at :math:`q / (a - 1)` if :math:`a > 1`. + The mode is :math:`q / (a + 1)`. + + This transformation is implemented as a linear interpolation which maps a + Gaussian onto an inverse gamma distribution. + + Parameters + ---------- + a : float + The shape-parameter of the inverse-gamma distribution. + scale : float + The scale-parameter of the inverse-gamma distribution. + loc : float + An option shift of the whole distribution. + step : float + Distance between sampling points for linear interpolation. + """ + from scipy.stats import invgamma, norm + + if not jnp.isscalar(a) or not jnp.isscalar(loc): + te = ( + "Shape `a` and location `loc` must be of scalar type" + f"; got {type(a)} and {type(loc)} respectively" + ) + raise TypeError(te) + if loc == 0.: + # Pull out `scale` to interpolate less + s2i = lambda x: invgamma.ppf(norm._cdf(x), a=a) + elif jnp.isscalar(scale): + s2i = lambda x: invgamma.ppf(norm._cdf(x), a=a, loc=loc, scale=scale) + else: + raise TypeError("`scale` may only be array-like for `loc == 0.`") + + xmin, xmax = -8.2, 8.2 # (1. - norm.cdf(8.2)) * 2 < 1e-15 + standard_to_invgamma_interp = interpolator( + s2i, xmin, xmax, step=step, table_func=jnp.log, inv_table_func=jnp.exp + ) + + def standard_to_invgamma(x): + # Allow for array-like `scale` without separate interpolations and only + # interpolate for shape `a` and `loc` + if loc == 0.: + return standard_to_invgamma_interp(x) * scale + return standard_to_invgamma_interp(x) + + return standard_to_invgamma + + +def invgamma_invprior(a, scale, loc=0., step=1e-2) -> Callable: + """Get the inverse transformation to `invgamma_prior`.""" + from scipy.stats import invgamma, norm + + xmin, xmax = -8.2, 8.2 # (1. 
- norm.cdf(8.2)) * 2 < 1e-15 + _, invgamma_to_standard = interpolator( + lambda x: invgamma.ppf(norm._cdf(x), a=a, loc=loc, scale=scale), + xmin, + xmax, + step=step, + table_func=jnp.log, + inv_table_func=jnp.exp, + return_inverse=True + ) + return invgamma_to_standard diff --git a/src/re/structured_kernel_interpolation.py b/src/re/structured_kernel_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2e388282aebc2cb3f3e03f2c4fa9d322983a99 --- /dev/null +++ b/src/re/structured_kernel_interpolation.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from typing import Callable, Optional, Tuple, Union + +import jax +from jax import numpy as jnp +import numpy as np + +from .correlated_field import get_fourier_mode_distributor, hartley + +NDArray = Union[jnp.ndarray, np.ndarray] + + +def interp_mat(grid_shape, grid_bounds, sampling_points, *, distances=None): + from scipy.sparse import coo_matrix # TODO: use only JAX w/o SciPy or NumPy + from jax.experimental.sparse import BCOO + + if sampling_points.ndim != 2: + ve = f"invalid dimension of sampling_points {sampling_points.ndim!r}" + raise ValueError(ve) + ndim, n_points = sampling_points.shape + if grid_bounds is not None and len(grid_bounds) != ndim: + ve = ( + f"grid_bounds of length {len(grid_bounds)} incompatible with" + " sampling_points of shape {sampling_points.shape!r}" + ) + raise ValueError(ve) + elif grid_bounds is not None: + offset = np.array(list(zip(*grid_bounds))[0]) + else: + offset = np.zeros(ndim) + if distances is not None and np.size(distances) != ndim: + ve = ( + f"distances of size {np.size(distances)} incompatible with" + " sampling_points of shape {sampling_points.shape!r}" + ) + raise ValueError(ve) + distances = np.asarray(distances) if distances is not None else None + if (distances is not None and grid_bounds + is not None) or (distances is None and grid_bounds is None): + raise ValueError("exactly one of `distances` or `grid_shape` expected") + elif grid_bounds is not None: + distances = np.array( + [(b[1] - b[0]) / sz for b, sz in zip(grid_bounds, grid_shape)] + ) + if distances is None: + raise AssertionError() + + mg = np.mgrid[(slice(0, 2), ) * ndim].reshape(ndim, -1) + pos = (sampling_points - offset.reshape(-1, 1)) / distances.reshape(-1, 1) + excess, pos = np.modf(pos) + pos = pos.astype(np.int64) + # max_index = np.array(grid_shape).reshape(-1, 1) + weights = np.zeros((2**ndim, n_points)) + ii = np.zeros((2**ndim, n_points), dtype=np.int64) + jj = np.zeros((2**ndim, n_points), dtype=np.int64) + for i in range(2**ndim): + weights[i, :] = np.prod( + np.abs(1 - mg[:, i].reshape(-1, 1) - excess), axis=0 + ) + fromi = (pos + mg[:, i].reshape(-1, 1)) # % max_index + ii[i, :] = np.arange(n_points) + jj[i, :] = np.ravel_multi_index(fromi, grid_shape) + + mat = coo_matrix( + (weights.ravel(), (ii.ravel(), jj.ravel())), + shape=(n_points, np.prod(grid_shape)) + ) + # BCOO( + # (weights.ravel(), jnp.stack((ii.ravel(), jj.ravel()), axis=1)), + # shape=(n_points, np.prod(grid_shape)) + # ) + return BCOO.from_scipy_sparse(mat) + + +class HarmonicSKI(): + def __init__( + self, + grid_shape: Tuple[int], + grid_bounds: Tuple[Tuple[float, float]], + sampling_points: NDArray, + harmonic_kernel: Optional[Callable] = None, + padding: float = 0.5, + subslice=None, + jitter: Union[bool, float, None] = True + ): + """Instantiate a KISS-GP model of the covariance using a harmonic + representation of the kernel. 
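+
+        Schematically, `__call__` applies the covariance as
+        `K x ~ w @ S @ w.T @ x + jitter * x`, where `w` is the sparse
+        multilinear interpolation matrix built by `interp_mat` and `S` is the
+        stationary covariance on the padded regular grid, applied via Hartley
+        transforms and the isotropic power spectrum (see `sandwich`).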
+ + Parameters + ---------- + grid_shape : + Number of pixels along each axes of the inducing points within + `grid_bounds`. + grid_bounds : + Tuple of boundaries of length of the number of dimensions. The + boundaries should denote the leftmost and rightmost edge of the + modeling space. + sampling_points : + Locations of the modeled points within the grid. + harmonic_kernel : + Harmonically transformed kernel. + padding : + Padding factor which to apply along each axis. + subslice : + Slice of the inducing points which to use to model + `sampling_points`. By default, the subslice is determined by the + padding. + jitter : + Strength of the diagonal jitter which to add to the covariance. + """ + if jitter is True: + if sampling_points.dtype.type == np.float64: + self.jitter = 1e-8 + elif sampling_points.dtype.type == np.float32: + self.jitter = 1e-6 + else: + raise NotImplementedError() + elif jitter is False: + self.jitter = None + else: + self.jitter = jitter + + self.grid_unpadded_shape = np.asarray(grid_shape) + self.grid_unpadded_bounds = np.asarray(grid_bounds) + self.grid_unpadded_distances = jnp.diff( + self.grid_unpadded_bounds, axis=1 + ).ravel() / self.grid_unpadded_shape + self.grid_unpadded_total_volume = jnp.prod( + self.grid_unpadded_shape * self.grid_unpadded_distances + ) + self.w = interp_mat(grid_shape, grid_bounds, sampling_points) + + if padding is not None and padding != 0.: + pad = 1. + padding + grid_shape = np.asarray(grid_shape) + grid_shape_wpad = np.ceil(grid_shape * pad).astype(int) + scl = grid_shape_wpad / grid_shape + scl_end = jnp.diff(jnp.asarray(grid_bounds), axis=1).ravel() * scl + grid_bounds_wpad = jnp.asarray(grid_bounds) + grid_bounds_wpad = grid_bounds_wpad.at[:, 1].set( + grid_bounds_wpad[:, 0].ravel() + scl_end + ) + if subslice is None: + subslice = tuple(map(int, grid_shape)) + grid_shape = grid_shape_wpad + grid_bounds = grid_bounds_wpad + self.grid_shape = np.asarray(grid_shape) + self.grid_bounds = np.asarray(grid_bounds) + self.grid_distances = jnp.diff(self.grid_bounds, + axis=1).ravel() / self.grid_shape + self.grid_total_volume = jnp.prod(self.grid_shape * self.grid_distances) + + self.power_distributor, self.unique_mode_lengths, _ = get_fourier_mode_distributor( + self.grid_shape, self.grid_distances + ) + + if subslice is not None: + if isinstance(subslice, slice): + subslice = (subslice, ) * len(self.grid_shape) + elif isinstance(subslice, int): + subslice = (slice(subslice), ) * len(self.grid_shape) + elif isinstance(subslice, tuple): + if all(isinstance(el, slice) for el in subslice): + pass + elif all(isinstance(el, int) for el in subslice): + subslice = tuple(slice(el) for el in subslice) + else: + raise TypeError("elements of `subslice` of invalid type") + else: + raise TypeError("`subslice` of invalid type") + self.grid_subslice = subslice + + self._harmonic_kernel = harmonic_kernel + + @property + def harmonic_kernel(self) -> Callable: + """Yields the harmonic kernel specified during initialization or throw + a `TypeError`. 
+ """ + if self._harmonic_kernel is None: + te = ( + "either specify a fixed harmonic kernel during initialization" + f" of the {self.__class__.__name__} class or provide one here" + ) + raise TypeError(te) + return self._harmonic_kernel + + def power(self, harmonic_kernel=None) -> NDArray: + if harmonic_kernel is None: + harmonic_kernel = self.harmonic_kernel + power = harmonic_kernel(self.unique_mode_lengths) + power *= self.grid_total_volume / self.grid_unpadded_total_volume + return power + + def amplitude(self, harmonic_kernel=None): + power = self.power(harmonic_kernel) + # Assume that the kernel scales linear with the total volume + return jnp.sqrt(power) + + def harmonic_transform(self, x) -> NDArray: + return 1. / self.grid_total_volume * hartley(x) + + def correlated_field(self, x, harmonic_kernel=None) -> NDArray: + amp = self.amplitude(harmonic_kernel) + f = self.harmonic_transform(amp[self.power_distributor] * x) + if self.grid_subslice is None: + return f + return f[self.grid_subslice] + + def sandwich(self, x, harmonic_kernel=None) -> NDArray: + if self.grid_subslice is None: + x_wpad = x + else: + x_wpad = jnp.zeros(tuple(self.grid_shape)) + x_wpad = x_wpad.at[self.grid_subslice].set(x) + + swd = jax.ShapeDtypeStruct(tuple(self.grid_shape), x.dtype) + ht = self.harmonic_transform + ht_T = jax.linear_transpose(self.harmonic_transform, swd) + + power = self.power(harmonic_kernel=harmonic_kernel) + s = ht(power[self.power_distributor] * ht_T(x_wpad)[0]) + if self.grid_subslice is None: + return s + return s[self.grid_subslice] + + def __call__(self, x, harmonic_kernel=None) -> NDArray: + """Applies the Covariance matrix.""" + x_shp = x.shape + jitter = 0. if self.jitter is None else self.jitter * x + + x = (self.w.T @ x.ravel()).reshape(tuple(self.grid_unpadded_shape)) + x = self.sandwich(x, harmonic_kernel=harmonic_kernel) + x = (self.w @ x.ravel()).reshape(x_shp) + return x + jitter + + def evaluate(self, harmonic_kernel=None): + """Instantiate the full covariance matrix.""" + probe = jnp.zeros(self.w.shape[0]) + indices = jnp.arange(self.w.shape[0]).reshape(1, -1) + + return jax.lax.map( + lambda idx: self( + probe.at[tuple(idx)].set(1.), harmonic_kernel=harmonic_kernel + ).ravel(), indices.T + ).T # vmap over `indices` w/ `in_axes=1, out_axes=-1` + + def evaluate_(self, kernel) -> NDArray: + from scipy.spatial import distance_matrix + + if self.jitter is None: + jitter = 0. 
+ else: + jitter = self.jitter * jnp.eye(self.w.shape[0]) + + p = [ + np.linspace(*b, num=sz, endpoint=True) for b, sz in + zip(self.grid_unpadded_bounds, self.grid_unpadded_shape) + ] + p = np.stack(np.meshgrid(*p, indexing="ij"), + axis=-1).reshape(-1, len(self.grid_unpadded_shape)) + kernel_inducing = kernel(distance_matrix(p, p)) + + return self.w @ kernel_inducing @ self.w.T + jitter diff --git a/src/re/sugar.py b/src/re/sugar.py new file mode 100644 index 0000000000000000000000000000000000000000..33dfd0468366804c1b2f3dbc2cdba3110ff98195 --- /dev/null +++ b/src/re/sugar.py @@ -0,0 +1,137 @@ +# Copyright(C) 2013-2021 Max-Planck-Society +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from collections.abc import Iterable +from typing import Any, Callable, Dict, Hashable, Mapping, TypeVar, Union + +from jax import numpy as jnp +from jax import random +from jax.tree_util import tree_map, tree_reduce, tree_structure, tree_unflatten + +from .field import Field + +O = TypeVar('O') +I = TypeVar('I') + + +def isiterable(candidate): + try: + iter(candidate) + return True + except (TypeError, AttributeError): + return False + + +def is1d(ls: Any) -> bool: + """Indicates whether the input is one dimensional. + + An object is considered one dimensional if it is an iterable of + non-iterable items. + """ + if hasattr(ls, "ndim"): + return ls.ndim == 1 + if not isiterable(ls): + return False + return all(not isiterable(e) for e in ls) + + +def doc_from(original): + def wrapper(target): + target.__doc__ = original.__doc__ + return target + + return wrapper + + +def ducktape(call: Callable[[I], O], + key: Hashable) -> Callable[[Mapping[Hashable, I]], O]: + def named_call(p): + return call(p[key]) + + return named_call + + +def ducktape_left(call: Callable[[I], O], + key: Hashable) -> Callable[[I], Dict[Hashable, O]]: + def named_call(p): + return {key: call(p)} + + return named_call + + +def sum_of_squares(tree) -> Union[jnp.ndarray, jnp.inexact]: + return tree_reduce(jnp.add, tree_map(lambda x: jnp.sum(x**2), tree), 0.) + + +def mean(forest): + from functools import reduce + + norm = 1. / len(forest) + if isinstance(forest[0], Field): + m = norm * reduce(Field.__add__, forest) + return m + else: + m = norm * reduce(Field.__add__, (Field(t) for t in forest)) + return m.val + + +def mean_and_std(forest, correct_bias=True): + if isinstance(forest[0], Field): + m = mean(forest) + mean_of_sq = mean(tuple(t**2 for t in forest)) + else: + m = Field(mean(forest)) + mean_of_sq = Field(mean(tuple(Field(t)**2 for t in forest))) + + n = len(forest) + scl = jnp.sqrt(n / (n - 1)) if correct_bias else 1. + std = scl * tree_map(jnp.sqrt, mean_of_sq - m**2) + if isinstance(forest[0], Field): + return m, std + else: + return m.val, std.val + + +def random_like(key: Iterable, primals, rng: Callable = random.normal): + import numpy as np + + struct = tree_structure(primals) + # Cast the subkeys to the structure of `primals` + subkeys = tree_unflatten(struct, random.split(key, struct.num_leaves)) + + def draw(key, x): + shp = x.shape if hasattr(x, "shape") else jnp.shape(x) + dtp = x.dtype if hasattr(x, "dtype") else np.common_type(x) + return rng(key=key, shape=shp, dtype=dtp) + + return tree_map(draw, subkeys, primals) + + +def interpolate(xmin=-7., xmax=7., N=14000) -> Callable: + """Replaces a local nonlinearity such as jnp.exp with a linear interpolation + + Interpolating functions speeds up code and increases numerical stability in + some cases, but at a cost of precision and range. 
+ + Parameters + ---------- + xmin : float + Minimal interpolation value. Default: -7. + xmax : float + Maximal interpolation value. Default: 7. + N : int + Number of points used for the interpolation. Default: 14000 + """ + def decorator(f): + from functools import wraps + + x = jnp.linspace(xmin, xmax, N) + y = f(x) + + @wraps(f) + def wrapper(t): + return jnp.interp(t, x, y) + + return wrapper + + return decorator diff --git a/src/sugar.py b/src/sugar.py index 7f5097f005b2340375d2d030bd9308f7b12238c4..36b7ac333e217ef50e63e4edcf78d33a300c0a23 100644 --- a/src/sugar.py +++ b/src/sugar.py @@ -60,7 +60,7 @@ def PS_field(pspace, function): Returns ------- - Field + :class:`nifty8.field.Field` A field defined on (pspace,) containing the computed function values """ if not isinstance(pspace, PowerSpace): @@ -119,7 +119,7 @@ def power_analyze(field, spaces=None, binbounds=None, Parameters ---------- - field : Field + field : :class:`nifty8.field.Field` The field to be analyzed spaces : None or int or tuple of int, optional The indices of subdomains for which the power spectrum shall be @@ -142,7 +142,7 @@ def power_analyze(field, spaces=None, binbounds=None, Returns ------- - Field + :class:`nifty8.field.Field` The output object. Its domain is a PowerSpace and it contains the power spectrum of `field`. """ @@ -203,7 +203,7 @@ def create_power_operator(domain, power_spectrum, space=None, ---------- domain : Domain, tuple of Domain or DomainTuple Domain on which the power operator shall be defined. - power_spectrum : callable or Field + power_spectrum : callable or :class:`nifty8.field.Field` An object that contains the power spectrum as a function of k. space : int the domain index on which the power operator will work @@ -318,7 +318,7 @@ def full(domain, val): Returns ------- - Field or MultiField + :class:`nifty8.field.Field` or:class:`nifty8.mulit_field.MultiField` The newly created uniform field """ if isinstance(domain, (dict, MultiDomain)): @@ -344,7 +344,7 @@ def from_random(domain, random_type='normal', dtype=np.float64, **kwargs): Returns ------- - Field or MultiField + :class:`nifty8.field.Field` or:class:`nifty8.mulit_field.MultiField` The newly created random field Notes @@ -372,7 +372,7 @@ def makeField(domain, arr): Returns ------- - Field or MultiField + :class:`nifty8.field.Field` or:class:`nifty8.mulit_field.MultiField` The newly created random field """ if isinstance(domain, (dict, MultiDomain)): @@ -405,7 +405,7 @@ def makeOp(inp, dom=None, sampling_dtype=None): Parameters ---------- - inp : None, Field or MultiField + inp : None, :class:`nifty8.field.Field` or :class:`nifty8.multi_field.MultiField` - if None, None is returned. - if Field on scalar-domain, a ScalingOperator with the coefficient given by the Field is returned. 
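As a usage illustration for the `interpolate` decorator added in `src/re/sugar.py` above (a sketch; it assumes the helper is re-exported as `nifty8.re.interpolate` alongside the other sugar functions):

    import jax.numpy as jnp
    import nifty8.re as jft

    @jft.interpolate(xmin=-7., xmax=7., N=14000)
    def fast_exp(x):
        return jnp.exp(x)

    # Piecewise-linear table lookup instead of evaluating `exp`; outside the
    # tabulated range `jnp.interp` clamps to the boundary values.
    fast_exp(jnp.array([0., 1., 10.]))  # ~ [1., 2.718, exp(7.)]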
diff --git a/test/test_operators/test_interpolated.py b/test/test_operators/test_interpolated.py index eae1c230d3534992297a1b6dca7f48760a387973..432b7c71c3ccb124dd21f577a1aa549a64c3f1ac 100644 --- a/test/test_operators/test_interpolated.py +++ b/test/test_operators/test_interpolated.py @@ -25,7 +25,6 @@ import nifty8 as ift from ..common import list2fixture, setup_function, teardown_function -pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize space = list2fixture([ift.GLSpace(15), ift.RGSpace(64, distances=.789), @@ -34,6 +33,7 @@ seed = list2fixture([4, 78, 23]) def testInterpolationAccuracy(space, seed): + ift.random.push_sseq_from_seed(seed) pos = ift.from_random(space, 'normal') alpha = 1.5 qs = [0.73, pos.ptw("exp").val] diff --git a/test/test_re/test_energies.py b/test/test_re/test_energies.py new file mode 100644 index 0000000000000000000000000000000000000000..9df3cd1b5735164518d7c7df2c913622f719a60a --- /dev/null +++ b/test/test_re/test_energies.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +import jax.numpy as jnp +import pytest +from functools import partial +from jax import random +from jax.tree_util import tree_map +from numpy.testing import assert_allclose + +import nifty8.re as jft + +pmp = pytest.mark.parametrize + + +def lst2fixt(lst): + @pytest.fixture(params=lst) + def fixt(request): + return request.param + + return fixt + + +def random_noise_std_inv(key, shape): + diag = 1. / random.exponential(key, shape) + + def noise_std_inv(tangents): + return diag * tangents + + return noise_std_inv + + +seed = lst2fixt((3639, 12, 41, 42)) +shape = lst2fixt(((4, 2), (2, 1), (5, ))) +lh_init_true = ( + ( + jft.Gaussian, { + "data": random.normal, + "noise_std_inv": random_noise_std_inv + }, None + ), ( + jft.StudentT, { + "data": random.normal, + "dof": random.exponential, + "noise_std_inv": random_noise_std_inv + }, None + ), ( + jft.Poissonian, { + "data": partial(random.poisson, lam=3.14) + }, random.exponential + ) +) +lh_init_approx = ( + ( + jft.VariableCovarianceGaussian, { + "data": random.normal + }, lambda key, shape: ( + random.normal(key, shape=shape), 1. / jnp. + exp(random.normal(key, shape=shape)) + ) + ), ( + jft.VariableCovarianceStudentT, { + "data": random.normal, + "dof": random.exponential + }, lambda key, shape: ( + random.normal(key, shape=shape), + jnp.exp(1. + random.normal(key, shape=shape)) + ) + ) +) + + +def test_gaussian_vs_vcgaussian_consistency(seed, shape): + rtol = 10 * jnp.finfo(jnp.zeros(0).dtype).eps + atol = 5 * jnp.finfo(jnp.zeros(0).dtype).eps + + key = random.PRNGKey(seed) + sk = list(random.split(key, 5)) + d = random.normal(sk.pop(), shape=shape) + m1 = random.normal(sk.pop(), shape=shape) + m2 = random.normal(sk.pop(), shape=shape) + t = random.normal(sk.pop(), shape=shape) + inv_std = 1. / jnp.exp(1. 
+ random.normal(sk.pop(), shape=shape)) + + gauss = jft.Gaussian(d, noise_std_inv=lambda x: inv_std * x) + vcgauss = jft.VariableCovarianceGaussian(d) + + diff_g = gauss(m2) - gauss(m1) + diff_vcg = vcgauss((m2, inv_std)) - vcgauss((m1, inv_std)) + assert_allclose(diff_g, diff_vcg, rtol=rtol, atol=atol) + + met_g = gauss.metric(m1, t) + met_vcg = vcgauss.metric((m1, inv_std), (t, d / 2))[0] + assert_allclose(met_g, met_vcg, rtol=rtol, atol=atol) + + +def test_studt_vs_vcstudt_consistency(seed, shape): + rtol = 10 * jnp.finfo(jnp.zeros(0).dtype).eps + atol = 4 * jnp.finfo(jnp.zeros(0).dtype).eps + + key = random.PRNGKey(seed) + sk = list(random.split(key, 6)) + d = random.normal(sk.pop(), shape=shape) + dof = random.normal(sk.pop(), shape=shape) + m1 = random.normal(sk.pop(), shape=shape) + m2 = random.normal(sk.pop(), shape=shape) + t = random.normal(sk.pop(), shape=shape) + inv_std = 1. / jnp.exp(1. + random.normal(sk.pop(), shape=shape)) + + studt = jft.StudentT(d, dof, noise_std_inv=lambda x: inv_std * x) + vcstudt = jft.VariableCovarianceStudentT(d, dof) + + diff_t = studt(m2) - studt(m1) + diff_vct = vcstudt((m2, 1. / inv_std)) - vcstudt((m1, 1. / inv_std)) + assert_allclose(diff_t, diff_vct, rtol=rtol, atol=atol) + + met_g = studt.metric(m1, t) + met_vcg = vcstudt.metric((m1, 1. / inv_std), (t, d / 2))[0] + assert_allclose(met_g, met_vcg, rtol=rtol, atol=atol) + + +@pmp("lh_init", lh_init_true + lh_init_approx) +def test_left_sqrt_metric_vs_metric_consistency(seed, shape, lh_init): + rtol = 4 * jnp.finfo(jnp.zeros(0).dtype).eps + atol = 0. + aallclose = partial(assert_allclose, rtol=rtol, atol=atol) + + N_TRIES = 5 + + lh_init_method, draw, latent_init = lh_init + key = random.PRNGKey(seed) + key, *subkeys = random.split(key, 1 + len(draw)) + init_kwargs = { + k: method(key=sk, shape=shape) + for (k, method), sk in zip(draw.items(), subkeys) + } + lh = lh_init_method(**init_kwargs) + + energy, lsm, lsm_shp = lh.energy, lh.left_sqrt_metric, lh.lsm_tangents_shape + # Let JIFTy infer the metric from the left-square-root-metric + lh_mini = jft.Likelihood( + energy, left_sqrt_metric=lsm, lsm_tangents_shape=lsm_shp + ) + + rng_method = latent_init if latent_init is not None else random.normal + for _ in range(N_TRIES): + key, *sk = random.split(key, 3) + p = rng_method(sk.pop(), shape=shape) + t = rng_method(sk.pop(), shape=shape) + tree_map(aallclose, lh.metric(p, t), lh_mini.metric(p, t)) + + +@pmp("lh_init", lh_init_true) +def test_transformation_vs_left_sqrt_metric_consistency(seed, shape, lh_init): + rtol = 4 * jnp.finfo(jnp.zeros(0).dtype).eps + atol = 0. 
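+    # From a transformation T into coordinates in which the metric is the
+    # identity, the left-square-root-metric is the adjoint Jacobian of T and
+    # the metric is J^T @ J; `lh_mini` below rederives both from T alone.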
+
+    N_TRIES = 5
+
+    lh_init_method, draw, latent_init = lh_init
+    key = random.PRNGKey(seed)
+    key, *subkeys = random.split(key, 1 + len(draw))
+    init_kwargs = {
+        k: method(key=sk, shape=shape)
+        for (k, method), sk in zip(draw.items(), subkeys)
+    }
+    lh = lh_init_method(**init_kwargs)
+    if lh._transformation is None:
+        pytest.skip("no transformation rule implemented yet")
+
+    energy, lsm_shp = lh.energy, lh.lsm_tangents_shape
+    # Let JIFTy infer the left-square-root-metric and the metric from the
+    # transformation
+    lh_mini = jft.Likelihood(
+        energy, transformation=lh._transformation, lsm_tangents_shape=lsm_shp
+    )
+
+    rng_method = latent_init if latent_init is not None else random.normal
+    for _ in range(N_TRIES):
+        key, *sk = random.split(key, 3)
+        p = rng_method(sk.pop(), shape=shape)
+        t = rng_method(sk.pop(), shape=shape)
+        assert_allclose(
+            lh.left_sqrt_metric(p, t),
+            lh_mini.left_sqrt_metric(p, t),
+            rtol=rtol,
+            atol=atol
+        )
+        assert_allclose(
+            lh.metric(p, t), lh_mini.metric(p, t), rtol=rtol, atol=atol
+        )
+
+
+if __name__ == "__main__":
+    test_gaussian_vs_vcgaussian_consistency(42, (5, ))
diff --git a/test/test_re/test_hmc_1d_distributions.py b/test/test_re/test_hmc_1d_distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eb520c7211f11b8815a67ec55869b5807082268
--- /dev/null
+++ b/test/test_re/test_hmc_1d_distributions.py
@@ -0,0 +1,120 @@
+import sys
+
+from jax import numpy as jnp
+from jax.scipy import stats
+from numpy.testing import assert_allclose
+import pytest
+import scipy
+from scipy.special import comb
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+
+def mnc2mc(mnc, wmean=True):
+    """Convert non-central to central moments using the recursive formula;
+    optionally adjust the first moment to return the mean.
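+
+    For instance, the first central moments in terms of the non-central
+    moments m1, m2, m3 are `mu2 = m2 - m1**2` and
+    `mu3 = m3 - 3*m1*m2 + 2*m1**3`; the recursion below evaluates exactly
+    these alternating-sign binomial sums.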
+ """ + + # https://www.statsmodels.org/stable/_modules/statsmodels/stats/moment_helpers.html + def _local_counts(mnc): + mean = mnc[0] + mnc = [1] + list(mnc) # add zero moment = 1 + mu = [] + for n, m in enumerate(mnc): + mu.append(0) + for k in range(n + 1): + sgn_comb = (-1)**(n - k) * comb(n, k, exact=True) + mu[n] += sgn_comb * mnc[k] * mean**(n - k) + if wmean: + mu[1] = mean + return mu[1:] + + res = jnp.apply_along_axis(_local_counts, 0, mnc) + # for backward compatibility convert 1-dim output to list/tuple + return res + + +# Test simple distributions with no extra parameters +dists = [stats.cauchy, stats.expon, stats.laplace, stats.logistic, stats.norm] +# Tuple of `rtol` and `atol` for every tested moment +moments_tol = {1: (0., 2e-1), 2: (3e-1, 0.), 3: (4e-1, 8e-1), 4: (4., 0.)} + + +@pmp("distribution", dists) +def test_moment_consistency(distribution, plot=False): + name = distribution.__name__.split('.')[-1] + + max_tree_depth = 20 + sampler = jft.NUTSChain( + potential_energy=lambda x: -1 * distribution.logpdf(x), + inverse_mass_matrix=1., + position_proto=jnp.array(0.), + step_size=0.7193, + max_tree_depth=max_tree_depth, + ) + chain, _ = sampler.generate_n_samples( + 42, jnp.array(1.03890), num_samples=1000, save_intermediates=True + ) + + # unique, counts = jnp.unique(chain.depths, return_counts=True) + # depths_frequencies = jnp.asarray((unique, counts)).T + + if plot is True: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(1, 2) + + bins = jnp.linspace(-10, 10) + if distribution is stats.expon: + bins = jnp.linspace(0, 10) + axs.flat[0].hist( + chain.samples, bins=bins, density=True, histtype="step" + ) + axs.flat[0].plot(bins, distribution.pdf(bins), color='r') + axs.flat[0].set_title(f"{name} PDF") + + axs.flat[1].hist( + chain.depths, + bins=jnp.arange(max_tree_depth + 1), + density=True, + histtype="step" + ) + axs.flat[1].set_title(f"Tree-depth") + fig.tight_layout() + plt.show() + + # central moments; except for the first (i.e. mean) + sample_moms_central = scipy.stats.moment(chain.samples, [1, 2, 3, 4, 5, 6]) + sample_moms_central[0] = jnp.mean(chain.samples) + + scipy_dist = getattr(scipy.stats, name) + dist_moms_non_central = jnp.array( + [scipy_dist.moment(i) for i in [1, 2, 3, 4, 5, 6]] + ) + dist_moms_central = mnc2mc(dist_moms_non_central, wmean=True) + + for i, (smpl_mom, dist_mom) in enumerate( + zip(sample_moms_central, dist_moms_central), start=1 + ): + msg = ( + f"{name} (moment {i}) :: sampled: {smpl_mom:+.2e}" + f" true: {dist_mom:+.2e} tested: " + ) + print(msg, end="", file=sys.stderr) + test = not jnp.isnan(dist_mom) + test &= not (jnp.allclose(dist_mom, 0.) 
and i > 1) + if i in moments_tol and test: + assert_allclose( + dist_mom, smpl_mom, + **dict(zip(("rtol", "atol"), moments_tol[i])) + ) + print("✓", file=sys.stderr) + else: + print("✗", file=sys.stderr) + + +if __name__ == "__main__": + for d in dists: + test_moment_consistency(d, plot=True) diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py new file mode 100644 index 0000000000000000000000000000000000000000..219951905c4d65ed738a74dfddd488ecd5169e03 --- /dev/null +++ b/test/test_re/test_hmc_hashes.py @@ -0,0 +1,90 @@ +import sys + +from jax import numpy as jnp +from jax.config import config as jax_config +from numpy import ndarray + +import nifty8.re as jft + + +NDARRAY_TYPE = [ndarray] + +try: + from jax.numpy import ndarray as jndarray + + NDARRAY_TYPE.append(jndarray) +except ImportError: + pass + +NDARRAY_TYPE = tuple(NDARRAY_TYPE) + + +def _json_serialize(obj): + if isinstance(obj, NDARRAY_TYPE): + return obj.tolist() + raise TypeError(f"unknown type {type(obj)}") + + +def hashit(obj, n_chars=8) -> str: + """Get first `n_chars` characters of Blake2B hash of `obj`.""" + import hashlib + import json + + return hashlib.blake2b( + bytes(json.dumps(obj, default=_json_serialize), "utf-8") + ).hexdigest()[:n_chars] + + +def test_hmc_hash(): + """Test sapmler output against known hash from previous commits.""" + x0 = jnp.array([0.1, 1.223], dtype=jnp.float32) + sampler = jft.HMCChain( + potential_energy=lambda x: jnp.sum(x**2), + inverse_mass_matrix=1., + position_proto=x0, + step_size=0.193, + num_steps=100, + max_energy_difference=1. + ) + chain, (key, pos) = sampler.generate_n_samples( + key=42, initial_position=x0, num_samples=1000, save_intermediates=True + ) + assert chain.divergences.sum() == 0 + accepted = chain.trees.accepted + results = (pos, key, chain.samples, accepted) + results_hash = hashit(results, n_chars=20) + print(f"full hash: {results_hash}", file=sys.stderr) + old_hash = "3d665689f809a98c81b3" + assert results_hash == old_hash + + +def test_nuts_hash(): + """Test sapmler output against known hash from previous commits.""" + jax_config.update("jax_enable_x64", False) + + x0 = jnp.array([0.1, 1.223], dtype=jnp.float32) + sampler = jft.NUTSChain( + potential_energy=lambda x: jnp.sum(x**2), + inverse_mass_matrix=1., + position_proto=x0, + step_size=0.193, + max_tree_depth=10, + bias_transition=False, + max_energy_difference=1. + ) + chain, (key, pos) = sampler.generate_n_samples( + key=42, initial_position=x0, num_samples=1000, save_intermediates=False + ) + assert chain.divergences.sum() == 0 + results = (pos, key, chain.samples) + results_hash = hashit(results, n_chars=20) + print(f"full hash: {results_hash}", file=sys.stderr) + old_hash = "8043850d7249acb77b26" + assert results_hash == old_hash + + jax_config.update("jax_enable_x64", True) + + +if __name__ == "__main__": + test_hmc_hash() + test_nuts_hash() diff --git a/test/test_re/test_hmc_leapfrog.py b/test/test_re/test_hmc_leapfrog.py new file mode 100644 index 0000000000000000000000000000000000000000..8062d2fc97e7374910edb4ff6970c171c0d3b27d --- /dev/null +++ b/test/test_re/test_hmc_leapfrog.py @@ -0,0 +1,89 @@ +import pytest +import sys +from jax import grad +from jax import numpy as jnp +from numpy.testing import assert_allclose + +import nifty8.re as jft + +pmp = pytest.mark.parametrize + +pot_and_tol = ( + ( + lambda q: jnp. 
+ sum(q.T @ jnp.linalg.inv(jnp.array([[1, 0.95], [0.95, 1]])) @ q / 2.), + 0.2 + ), (lambda q: -1 / jnp.linalg.norm(q), 2e-2) +) + + +@pmp("potential_energy, rtol", pot_and_tol) +def test_leapfrog_energy_conservation(potential_energy, rtol): + dims = (2, ) + mass_matrix = jnp.ones(shape=dims) + kinetic_energy = lambda p: jnp.sum(p**2 / mass_matrix / 2.) + + potential_energy_gradient = grad(potential_energy) + positions = [jnp.array([-1.5, -1.55])] + momenta = [jnp.array([-1, 1])] + for _ in range(25): + new_qp = jft.hmc.leapfrog_step( + qp=jft.hmc.QP(position=positions[-1], momentum=momenta[-1]), + potential_energy_gradient=potential_energy_gradient, + kinetic_energy_gradient=lambda x, y: x * y, + step_size=0.25, + inverse_mass_matrix=1. / mass_matrix + ) + positions.append(new_qp.position) + momenta.append(new_qp.momentum) + + potential_energies = list(map(potential_energy, positions)) + kinetic_energies = list(map(kinetic_energy, momenta)) + + jnp.set_printoptions(precision=2) + for q, p, e_kin, e_pot in zip( + positions, momenta, potential_energies, kinetic_energies + ): + msg = ( + f"q: {q}; p: {p}" + f"\nE_tot: {e_pot+e_kin:.2e}; E_pot: {e_pot:.2e}; E_kin: {e_kin:.2e}" + ) + print(msg, file=sys.stderr) + + old_energy_tot = potential_energies[0] + kinetic_energies[0] + new_energy_tot = potential_energies[-1] + kinetic_energies[-1] + assert_allclose(old_energy_tot, new_energy_tot, rtol=rtol) + + return positions, momenta, kinetic_energies, potential_energies + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + qs, ps, e_kins, e_pots = test_leapfrog_energy_conservation(*pot_and_tol[0]) + positions = jnp.array(qs) + momenta = jnp.array(ps) + kinetic_energies = jnp.array(e_kins) + potential_energies = jnp.array(e_pots) + + # Position Coordinates + plt.plot(positions[:, 0], positions[:, 1]) + plt.xlabel("position[:,0]") + plt.ylabel("position[:,1]") + plt.show() + + # Momentum coordinates + plt.plot(momenta[:, 0], momenta[:, 1]) + plt.xlabel("momenta[:,0]") + plt.ylabel("momenta[:,1]") + plt.show() + + # Value of Hamiltonian + # does not look exactly the same as in Neal (2011) unfortunately! 
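+    # Leapfrog is symplectic: the total energy oscillates about its initial
+    # value with an error of order step_size**2 rather than drifting, which
+    # is why the assert in the test above can demand approximate conservation
+    # after many steps.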
+ plt.plot(kinetic_energies, label='kin') + plt.plot(potential_energies, label='pot') + plt.plot(kinetic_energies + potential_energies, label='total') + plt.xlabel('time') + plt.ylabel('energy') + plt.legend() + plt.show() diff --git a/test/test_re/test_hmc_pytree.py b/test/test_re/test_hmc_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6f1e3c89e3fc1f86819834970cac3506075003 --- /dev/null +++ b/test/test_re/test_hmc_pytree.py @@ -0,0 +1,75 @@ +from functools import partial +from jax import numpy as jnp +from jax.tree_util import tree_leaves +from numpy.testing import assert_array_equal + +import nifty8.re as jft + + +def test_hmc_pytree(): + """Test sapmler output against known hash from previous commits.""" + initial_position = jnp.array([0.31415, 2.71828]) + + sampler_init = partial( + jft.HMCChain, + potential_energy=jft.sum_of_squares, + inverse_mass_matrix=1., + step_size=0.193, + num_steps=100 + ) + + initial_position_py = jft.Field(({"lvl0": initial_position}, )) + smpl_w_pytree = sampler_init(position_proto=initial_position_py + ).generate_n_samples( + key=321, + initial_position=initial_position_py, + num_samples=1000 + ) + smpl_wo_pytree = sampler_init(position_proto=initial_position + ).generate_n_samples( + key=321, + initial_position=initial_position, + num_samples=1000 + ) + + ts_w, ts_wo = tree_leaves(smpl_w_pytree), tree_leaves(smpl_wo_pytree) + assert len(ts_w) == len(ts_wo) + for w, wo in zip(ts_w, ts_wo): + assert_array_equal(w, wo) + + +def test_nuts_pytree(): + """Test sapmler output against known hash from previous commits.""" + initial_position = jnp.array([0.31415, 2.71828]) + + sampler_init = partial( + jft.NUTSChain, + potential_energy=jft.sum_of_squares, + inverse_mass_matrix=1., + step_size=0.193, + max_tree_depth=10, + ) + + initial_position_py = jft.Field(({"lvl0": initial_position}, )) + smpl_w_pytree = sampler_init(position_proto=initial_position_py + ).generate_n_samples( + key=323, + initial_position=initial_position_py, + num_samples=1000 + ) + smpl_wo_pytree = sampler_init(position_proto=initial_position + ).generate_n_samples( + key=323, + initial_position=initial_position, + num_samples=1000 + ) + + ts_w, ts_wo = tree_leaves(smpl_w_pytree), tree_leaves(smpl_wo_pytree) + assert len(ts_w) == len(ts_wo) + for w, wo in zip(ts_w, ts_wo): + assert_array_equal(w, wo) + + +if __name__ == "__main__": + test_hmc_pytree() + test_nuts_pytree() diff --git a/test/test_re/test_lanczos.py b/test/test_re/test_lanczos.py new file mode 100644 index 0000000000000000000000000000000000000000..abc4edbac7510c90ccb281abba882bfd259ebae0 --- /dev/null +++ b/test/test_re/test_lanczos.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys +from jax import random +import jax.numpy as jnp +import numpy as np +from numpy.testing import assert_allclose +import pytest +from scipy.spatial import distance_matrix + +import nifty8.re as jft + +pmp = pytest.mark.parametrize + + +def matern_kernel(distance, scale, cutoff, dof): + from jax.scipy.special import gammaln + from scipy.special import kv + + reg_dist = jnp.sqrt(2 * dof) * distance / cutoff + cov = scale**2 * 2**(1 - dof) / jnp.exp( + gammaln(dof) + ) * (reg_dist)**dof * kv(dof, reg_dist) + # NOTE, this is not safe for differentiating because `cov` still may + # contain NaNs + return jnp.where(distance < 1e-8 * cutoff, scale**2, cov) + + +from operator import matmul + + +@pmp("seed", tuple(range(12, 44, 5))) 
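+# Lanczos tridiagonalization builds an orthonormal Krylov basis `vecs` and a
+# tridiagonal matrix `tridiag` with `m ~ vecs.T @ tridiag @ vecs`; running the
+# full `shape0` iterations on a symmetric positive-definite `m` makes this
+# factorization exact up to round-off, which the test below asserts.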
+@pmp("shape0", (128, 64)) +def test_lanczos_tridiag(seed, shape0): + rng = np.random.default_rng(seed) + rng_key = random.PRNGKey(rng.integers(12, 42)) + + m = rng.normal(size=(shape0, ) * 2) + m = m @ m.T # ensure positive-definiteness + + tridiag, vecs = jft.lanczos.lanczos_tridiag( + partial(matmul, m), jft.ShapeWithDtype((shape0, )), shape0, rng_key + ) + m_est = vecs.T @ tridiag @ vecs + + np.testing.assert_allclose(m_est, m, atol=1e-13, rtol=1e-13) + + +@pmp("seed", tuple(range(12, 44, 5))) +@pmp("shape0", (128, 64)) +def test_stochastic_lq_logdet(seed, shape0, lq_order=15, n_lq_samples=10): + rng = np.random.default_rng(seed) + rng_key = random.PRNGKey(rng.integers(12, 42)) + + c = np.exp(3 + rng.normal()) + s = np.exp(rng.normal()) + + p = np.logspace(np.log(0.1 * c), np.log(1e+2 * c), num=shape0 - 1) + p = np.concatenate(([0], p)).reshape(-1, 1) + + m = jnp.asarray( + matern_kernel(distance_matrix(p, p), cutoff=c, scale=s, dof=2.5) + ) + + _, logdet = jnp.linalg.slogdet(m) + logdet_est = jft.stochastic_lq_logdet(m, lq_order, n_lq_samples, rng_key) + assert_allclose(logdet_est, logdet, rtol=2., atol=20.) + print(f"{logdet=} :: {logdet_est=}", file=sys.stderr) diff --git a/test/test_re/test_ncg.py b/test/test_re/test_ncg.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd9737d189c1fa66c4ca5a65582d1c1a839ce52 --- /dev/null +++ b/test/test_re/test_ncg.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 + +import sys + +from jax import random, value_and_grad +import jax.numpy as jnp +from numpy.testing import assert_allclose +import pytest + +import nifty8.re as jft + +pmp = pytest.mark.parametrize + + +def rosenbrock(np): + def func(x): + return jnp.sum(100. * jnp.diff(x)**2 + (1. - x[:-1])**2) + + return func + + +def himmelblau(np): + def func(p): + x, y = p + return (x**2 + y - 11.)**2 + (x + y**2 - 7.)**2 + + return func + + +def matyas(np): + def func(p): + x, y = p + return 0.26 * (x**2 + y**2) - 0.48 * x * y + + return func + + +def eggholder(np): + def func(p): + x, y = p + return -(y + 47) * jnp.sin( + jnp.sqrt(jnp.abs(x / 2. + y + 47.)) + ) - x * jnp.sin(jnp.sqrt(jnp.abs(x - (y + 47.)))) + + return func + + +def test_ncg_for_pytree(): + pos = jft.Field( + [ + jnp.array(0., dtype=jnp.float32), + (jnp.array(3., dtype=jnp.float32), ), { + "a": jnp.array(5., dtype=jnp.float32) + } + ] + ) + getters = (lambda x: x[0], lambda x: x[1][0], lambda x: x[2]["a"]) + tgt = [-10., 1., 2.] 
+ met = [10., 40., 2] + + def model(p): + losses = [] + for i, get in enumerate(getters): + losses.append((get(p) - tgt[i])**2 * met[i]) + return jnp.sum(jnp.array(losses)) + + def metric(p, tan): + m = [] + m.append(tan[0] * met[0]) + m.append((tan[1][0] * met[1], )) + m.append({"a": tan[2]["a"] * met[2]}) + return jft.Field(m) + + res = jft.newton_cg( + fun_and_grad=value_and_grad(model), + x0=pos, + hessp=metric, + maxiter=10, + absdelta=1e-6 + ) + for i, get in enumerate(getters): + assert_allclose(get(res), tgt[i], atol=1e-6, rtol=1e-5) + + +@pmp("seed", (3637, 12, 42)) +def test_ncg(seed): + key = random.PRNGKey(seed) + x = random.normal(key, shape=(3, )) + diag = jnp.array([1., 2., 3.]) + met = lambda y, t: t / diag + val_and_grad = lambda y: ( + jnp.sum(y**2 / diag) / 2 - jnp.dot(x, y), y / diag - x + ) + + res = jft.newton_cg( + fun_and_grad=val_and_grad, + x0=x, + hessp=met, + maxiter=20, + absdelta=1e-6, + name='N' + ) + assert_allclose(res, diag * x, rtol=1e-4, atol=1e-4) + + +@pmp("seed", (3637, 12, 42)) +@pmp("cg", (jft.cg, jft.static_cg)) +def test_cg(seed, cg): + key = random.PRNGKey(seed) + sk = random.split(key, 2) + x = random.normal(sk[0], shape=(3, )) + # Avoid poorly conditioned matrices by shifting the elements from zero + diag = 6. + random.normal(sk[1], shape=(3, )) + mat = lambda x: x / diag + + res, _ = cg(mat, x, resnorm=1e-5, absdelta=1e-5) + assert_allclose(res, diag * x, rtol=1e-4, atol=1e-4) + + +@pmp("seed", (3637, 12, 42)) +@pmp("cg", (jft.cg, jft.static_cg)) +def test_cg_non_pos_def_failure(seed, cg): + key = random.PRNGKey(seed) + sk = random.split(key, 2) + + x = random.normal(sk[0], shape=(4, )) + # Purposely produce a non-positive definite matrix + diag = jnp.concatenate( + (jnp.array([-1]), 6. + random.normal(sk[1], shape=(3, ))) + ) + mat = lambda x: x / diag + + with pytest.raises(ValueError): + _, info = cg(mat, x, resnorm=1e-5, absdelta=1e-5) + if info < 0: + raise ValueError() + + +@pmp("seed", (3637, 12, 42)) +def test_cg_steihaug(seed): + key = random.PRNGKey(seed) + sk = random.split(key, 2) + x = random.normal(sk[0], shape=(3, )) + # Avoid poorly conditioned matrices by shifting the elements from zero + diag = 6. + random.normal(sk[1], shape=(3, )) + mat = lambda x: x / diag + + # Note, the solution to the subproblem with infinite trust radius is the CG + # but with the opposite sign + res = jft.conjugate_gradient._cg_steihaug_subproblem( + jnp.nan, -x, mat, resnorm=1e-6, trust_radius=jnp.inf + ) + assert_allclose(res.step, diag * x, rtol=1e-4, atol=1e-4) + + +@pmp("seed", (3637, 12, 42)) +@pmp("size", (5, 9, 14)) +def test_cg_steihaug_vs_cg_consistency(seed, size): + key = random.PRNGKey(seed) + sk = random.split(key, 2) + + x = random.normal(sk[0], shape=(size, )) + # Avoid poorly conditioned matrices by shifting the elements from zero + mat_val = 6. 
+ random.normal(sk[1], shape=(size, size)) + mat_val = mat_val @ mat_val.T # Construct a symmetric matrix + mat = lambda x: mat_val @ x + + # Note, the solution to the subproblem with infinite trust radius is the CG + # but with the opposite sign + for i in range(4): + print(f"Iteratoin {i:02d}", file=sys.stderr) + res_cgs = jft.conjugate_gradient._cg_steihaug_subproblem( + jnp.nan, + -x, + mat, + resnorm=1e-6, + trust_radius=jnp.inf, + miniter=i, + maxiter=i + ) + res_cg_plain, _ = jft.conjugate_gradient.cg( + mat, x, resnorm=1e-6, miniter=i, maxiter=i + ) + assert_allclose(res_cgs.step, res_cg_plain, rtol=1e-4, atol=1e-5) + + +@pmp( + "fun_and_init", ( + (rosenbrock, jnp.zeros(2)), (himmelblau, jnp.zeros(2)), + (matyas, jnp.ones(2) * 6.), (eggholder, jnp.ones(2) * 100.) + ) +) +@pmp("maxiter", (jnp.inf, None)) +def test_minimize(fun_and_init, maxiter): + from scipy.optimize import minimize as opt_minimize + from jax import grad, hessian + + func, x0 = fun_and_init + + def jft_minimize(x0): + result = jft.minimize( + func(jnp), + x0, + method='trust-ncg', + options=dict( + maxiter=maxiter, + energy_reduction_factor=None, + gtol=1e-6, + initial_trust_radius=1., + max_trust_radius=1000. + ), + ) + return result.x + + def scp_minimize(x0): + # Use JAX primitives to take derivates + fun = func(jnp) + result = opt_minimize( + fun, x0, jac=grad(fun), hess=hessian(fun), method='trust-ncg' + ) + return result.x + + jax_res = jft_minimize(x0) + scipy_res = scp_minimize(x0) + assert_allclose(scipy_res, jax_res, rtol=2e-6, atol=2e-5) + + +if __name__ == "__main__": + test_ncg_for_pytree() diff --git a/test/test_re/test_refine.py b/test/test_re/test_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad0075073a0a44c1109617d34116fc4f3a83f29 --- /dev/null +++ b/test/test_re/test_refine.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause + +from functools import partial +import sys + +import jax +from jax import random +import jax.numpy as jnp +from jax.tree_util import Partial +import numpy as np +from numpy.testing import assert_allclose +import pytest +from scipy.spatial import distance_matrix + +import nifty8.re as jft +from nifty8.re import refine, refine_chart + +pmp = pytest.mark.parametrize + + +def matern_kernel(distance, scale, cutoff, dof): + from jax.scipy.special import gammaln + from scipy.special import kv + + reg_dist = jnp.sqrt(2 * dof) * distance / cutoff + return scale**2 * 2**(1 - dof) / jnp.exp( + gammaln(dof) + ) * (reg_dist)**dof * kv(dof, reg_dist) + + +scale, cutoff, dof = 1., 80., 3 / 2 + +x = jnp.logspace(-6, 11, base=jnp.e, num=int(1e+5)) +y = matern_kernel(x, scale, cutoff, dof) +y = jnp.nan_to_num(y, nan=0.) 
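+# `scipy.special.kv` is not JAX-traceable, so the Matern kernel is tabulated
+# once on a log-spaced grid and the tests use a cheap `jnp.interp` lookup
+# (wrapped in `jax.tree_util.Partial` so it can be passed as a pytree leaf).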
+kernel = Partial(jnp.interp, xp=x, fp=y) +inv_kernel = Partial(jnp.interp, xp=y, fp=x) + + +@pmp("dist", (10., 20., 30., 1e+3)) +def test_refinement_matrices_1d(dist, kernel=kernel): + cov_from_loc = refine._get_cov_from_loc(kernel=kernel) + + coarse_coord = dist * jnp.array([0., 1., 2.]) + fine_coord = coarse_coord[tuple( + jnp.array(coarse_coord.shape) // 2 + )] + (jnp.diff(coarse_coord) / jnp.array([-4., 4.])) + cov_ff = cov_from_loc(fine_coord, fine_coord) + cov_fc = cov_from_loc(fine_coord, coarse_coord) + cov_cc_inv = jnp.linalg.inv(cov_from_loc(coarse_coord, coarse_coord)) + + fine_kernel = cov_ff - cov_fc @ cov_cc_inv @ cov_fc.T + fine_kernel_sqrt_diy = jnp.linalg.cholesky(fine_kernel) + olf_diy = cov_fc @ cov_cc_inv + + olf, fine_kernel_sqrt = refine.layer_refinement_matrices(dist, kernel) + + assert_allclose(olf, olf_diy) + assert_allclose(fine_kernel_sqrt, fine_kernel_sqrt_diy) + + +@pmp("seed", (12, 42, 43, 45)) +@pmp("dist", (10., 20., 30., 1e+3)) +def test_refinement_1d(seed, dist, kernel=kernel): + rng = np.random.default_rng(seed) + + refs = ( + refine.refine_conv, refine.refine_conv_general, refine.refine_loop, + refine.refine_vmap, refine.refine_loop, refine.refine_slice + ) + cov_from_loc = refine._get_cov_from_loc(kernel=kernel) + olf, fine_kernel_sqrt = refine.layer_refinement_matrices(dist, kernel) + + main_coord = jnp.linspace(0., 1000., 50) + cov_sqrt = jnp.linalg.cholesky(cov_from_loc(main_coord, main_coord)) + lvl0 = cov_sqrt @ rng.normal(size=main_coord.shape) + lvl1_exc = rng.normal(size=(2 * (lvl0.size - 2), )) + + fine_reference = refine.refine(lvl0, lvl1_exc, olf, fine_kernel_sqrt) + eps = jnp.finfo(lvl0.dtype.type).eps + aallclose = partial( + assert_allclose, desired=fine_reference, rtol=6 * eps, atol=60 * eps + ) + for ref in refs: + print(f"testing {ref.__name__}", file=sys.stderr) + aallclose(ref(lvl0, lvl1_exc, olf, fine_kernel_sqrt)) + + +@pmp("seed", (12, 42)) +@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4))) +@pmp("_coarse_size", (3, 5)) +@pmp("_fine_size", (2, 4)) +@pmp("_fine_strategy", ("jump", "extend")) +def test_refinement_nd_cross_consistency( + seed, dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel +): + ndim = len(dist) if hasattr(dist, "__len__") else 1 + min_shape = (12, ) * ndim + depth = 1 + refs = (refine.refine_conv_general, refine.refine_slice) + kwargs = { + "_coarse_size": _coarse_size, + "_fine_size": _fine_size, + "_fine_strategy": _fine_strategy + } + + chart = refine_chart.CoordinateChart( + min_shape, depth=depth, distances=dist, **kwargs + ) + rfm = refine_chart.RefinementField(chart).matrices(kernel) + xi = jft.random_like( + random.PRNGKey(seed), + refine_chart.RefinementField(chart).shapewithdtype + ) + + cf = partial(refine_chart.RefinementField.apply, chart=chart, kernel=rfm) + fine_reference = cf(xi) + eps = jnp.finfo(fine_reference.dtype.type).eps + aallclose = partial( + assert_allclose, desired=fine_reference, rtol=6 * eps, atol=60 * eps + ) + for ref in refs: + print(f"testing {ref.__name__}", file=sys.stderr) + aallclose(cf(xi, _refine=ref)) + + +@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4))) +def test_refinement_fine_strategy_basic_consistency(dist, kernel=kernel): + olf_j, ks_j = refine.layer_refinement_matrices( + dist, kernel=kernel, _fine_size=2, _fine_strategy="jump" + ) + olf_e, ks_e = refine.layer_refinement_matrices( + dist, kernel=kernel, _fine_size=2, _fine_strategy="extend" + ) + + assert_allclose(olf_j, olf_e, rtol=1e-13, atol=0.) 
+ assert_allclose(ks_j, ks_e, rtol=1e-13, atol=0.)
+
+ shape0 = (12, ) * len(dist) if isinstance(dist, tuple) else (12, )
+ depth = 2
+ olfs_j, (csq0_j, kss_j) = refine.refinement_matrices(
+ shape0, depth, dist, kernel=kernel, _fine_strategy="jump"
+ )
+ olfs_e, (csq0_e, kss_e) = refine.refinement_matrices(
+ shape0, depth, dist, kernel=kernel, _fine_strategy="extend"
+ )
+
+ assert_allclose(olfs_j, olfs_e, rtol=1e-13, atol=0.)
+ assert_allclose(kss_j, kss_e, rtol=1e-13, atol=0.)
+ assert_allclose(csq0_j, csq0_e, rtol=1e-13, atol=0.)
+
+
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_refinement_covariance(
+ dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel
+):
+ distances0 = np.atleast_1d(dist)
+ ndim = len(distances0)
+
+ cf = refine_chart.RefinementField(
+ shape0=(_coarse_size, ) * ndim,
+ depth=1,
+ _coarse_size=_coarse_size,
+ _fine_size=_fine_size,
+ _fine_strategy=_fine_strategy,
+ distances0=distances0,
+ kernel=kernel
+ )
+ exc_shp = [
+ jft.ShapeWithDtype((_coarse_size, ) * ndim),
+ jft.ShapeWithDtype((_fine_size, ) * ndim)
+ ]
+ cf_shp = jax.eval_shape(cf, exc_shp)
+ assert cf_shp.shape == (_fine_size, ) * ndim
+
+ probe = jnp.zeros(cf_shp.shape)
+ indices = np.indices(cf_shp.shape).reshape(ndim, -1)
+ # Work around jax.linear_transpose NotImplementedError
+ _, cf_T = jax.vjp(cf, jft.zeros_like(exc_shp))
+ cf_cf_T = lambda x: cf(*cf_T(x))
+ cov_empirical = jax.vmap(
+ lambda idx: cf_cf_T(probe.at[tuple(idx)].set(1.)).ravel(),
+ in_axes=1,
+ out_axes=-1
+ )(indices)
+
+ pos = np.mgrid[tuple(slice(s) for s in cf_shp.shape)].astype(float)
+ if _fine_strategy == "jump":
+ pos *= distances0.reshape((-1, ) + (1, ) * ndim) / _fine_size
+ elif _fine_strategy == "extend":
+ pos *= distances0.reshape((-1, ) + (1, ) * ndim) / 2
+ else:
+ raise AssertionError(f"invalid `_fine_strategy`; {_fine_strategy}")
+ pos = jnp.moveaxis(pos, 0, -1)
+ p = pos.reshape(-1, ndim)
+ dist_mat = distance_matrix(p, p)
+ cov_truth = kernel(dist_mat)
+
+ assert_allclose(cov_empirical, cov_truth, rtol=1e-14, atol=1e-15)
+
+
+@pmp("seed", (12, 42, 43, 45))
+@pmp("n_dim", (1, 2, 3, 4, 5))
+def test_refinement_nd_shape(seed, n_dim, kernel=kernel):
+ rng = np.random.default_rng(seed)
+
+ distances = np.exp(rng.normal(size=(n_dim, )))
+ cov_from_loc = refine._get_cov_from_loc(kernel=kernel)
+ olf, fine_kernel_sqrt = refine.layer_refinement_matrices(distances, kernel)
+
+ shp_i = 5
+ gc = distances.reshape(n_dim, 1) * jnp.linspace(0., 1000., shp_i)
+ gc = jnp.stack(jnp.meshgrid(*gc, indexing="ij"), axis=-1).reshape(-1, n_dim)
+ cov_sqrt = jnp.linalg.cholesky(cov_from_loc(gc, gc))
+ lvl0 = (cov_sqrt @ rng.normal(size=gc.shape[0])).reshape((shp_i, ) * n_dim)
+ lvl1_exc = rng.normal(size=tuple(n - 2 for n in lvl0.shape) + (2**n_dim, ))
+
+ fine_reference = refine.refine(lvl0, lvl1_exc, olf, fine_kernel_sqrt)
+ assert fine_reference.shape == tuple((2 * (shp_i - 2), ) * n_dim)
+
+
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_chart_pixel_refinement_matrices_consistency(
+ dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel
+):
+ depth = 3
+ distances = np.atleast_1d(dist)
+ kwargs = {
+ "_coarse_size": _coarse_size,
+ "_fine_size": _fine_size,
+ "_fine_strategy": _fine_strategy
+ }
+
+ cc = refine_chart.CoordinateChart(
+ (12, ) * distances.size, depth=depth, distances=distances, **kwargs
+ )
+ olf, ks = refine_chart.RefinementField(cc).matrices_at(
+ level=depth, pixel_index=(0, ) * distances.size, kernel=kernel
+ )
+ olf_classical, ks_classical = refine.layer_refinement_matrices(
+ distances, kernel, **kwargs
+ )
+ assert_allclose(olf, olf_classical, atol=1e-14, rtol=1e-14)
+ assert_allclose(ks, ks_classical, atol=1e-14, rtol=1e-14)
+
+
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_chart_refinement_matrices_consistency(
+ dist, _coarse_size, _fine_size, _fine_strategy, kernel=kernel
+):
+ depth = 3
+ distances = np.atleast_1d(dist)
+ ndim = distances.size
+ kwargs = {
+ "_coarse_size": _coarse_size,
+ "_fine_size": _fine_size,
+ "_fine_strategy": _fine_strategy
+ }
+
+ cc = refine_chart.CoordinateChart(
+ (12, ) * ndim, depth=depth, distances=distances, **kwargs
+ )
+ refinement = refine_chart.RefinementField(cc).matrices(kernel=kernel)
+
+ cc_irreg = refine_chart.CoordinateChart(
+ shape0=cc.shape0,
+ depth=depth,
+ distances=distances,
+ irregular_axes=tuple(range(ndim)),
+ **kwargs
+ )
+ refinement_irreg = refine_chart.RefinementField(cc_irreg).matrices(
+ kernel=kernel
+ )
+
+ _, (cov_sqrt0, _) = refine.refinement_matrices(
+ cc.shape0, 0, cc.distances0, kernel, **kwargs
+ )
+
+ aallclose = partial(assert_allclose, rtol=1e-14, atol=1e-13)
+ aallclose(refinement.cov_sqrt0, cov_sqrt0)
+ aallclose(refinement_irreg.cov_sqrt0, cov_sqrt0)
+
+ for lvl in range(depth):
+ olf, ks = refinement.filter[lvl], refinement.propagator_sqrt[lvl]
+ olf_irreg = refinement_irreg.filter[lvl]
+ ks_irreg = refinement_irreg.propagator_sqrt[lvl]
+
+ if _fine_strategy == "jump":
+ distances_lvl = cc.distances0 / _fine_size**lvl
+ elif _fine_strategy == "extend":
+ distances_lvl = cc.distances0 / 2**lvl
+ else:
+ raise AssertionError(f"invalid `_fine_strategy`; {_fine_strategy}")
+ olf_classical, ks_classical = refine.layer_refinement_matrices(
+ distances_lvl, kernel, **kwargs
+ )
+
+ aallclose(olf.squeeze(), olf_classical)
+ aallclose(ks.squeeze(), ks_classical)
+
+ # On a regular grid, all pixel-wise matrices of the irregular chart
+ # must be identical to one another
+ olf_d = np.diff(
+ olf_irreg.reshape((-1, ) + olf_irreg.shape[-2:]), axis=0
+ )
+ ks_d = np.diff(ks_irreg.reshape((-1, ) + ks_irreg.shape[-2:]), axis=0)
+ aallclose(olf_d, 0.)
+ aallclose(ks_d, 0.)
+ aallclose(olf_irreg[(0, ) * ndim], olf_classical)
+ aallclose(ks_irreg[(0, ) * ndim], ks_classical)
+
+
+@pmp("seed", (12, ))
+@pmp("dist", (60., 1e+3, (80., 80.), (40., 90.), (1e+2, 1e+3, 1e+4)))
+@pmp("_coarse_size", (3, 5))
+@pmp("_fine_size", (2, 4))
+@pmp("_fine_strategy", ("jump", "extend"))
+@pmp("_refine", (refine.refine_conv_general, refine.refine_slice))
+def test_refinement_irregular_regular_consistency(
+ seed,
+ dist,
+ _coarse_size,
+ _fine_size,
+ _fine_strategy,
+ _refine,
+ kernel=kernel
+):
+ depth = 1
+ distances = np.atleast_1d(dist)
+ ndim = distances.size
+ kwargs = {
+ "_coarse_size": _coarse_size,
+ "_fine_size": _fine_size,
+ "_fine_strategy": _fine_strategy
+ }
+
+ cc = refine_chart.RefinementField(
+ shape0=(2 * _coarse_size, ) * ndim,
+ depth=depth,
+ distances=distances,
+ **kwargs
+ )
+ refinement = cc.matrices(kernel=kernel)
+
+ cc_irreg = refine_chart.RefinementField(
+ shape0=cc.chart.shape0,
+ depth=depth,
+ distances=distances,
+ irregular_axes=tuple(range(ndim)),
+ **kwargs
+ )
+ refinement_irreg = cc_irreg.matrices(kernel=kernel)
+
+ rng = np.random.default_rng(seed)
+ exc_swd = cc.shapewithdtype[-1]
+ fn1 = rng.normal(size=cc.chart.shape_at(depth - 1))
+ exc = rng.normal(size=exc_swd.shape)
+
+ refined = _refine(
+ fn1, exc, refinement.filter[-1], refinement.propagator_sqrt[-1],
+ **kwargs
+ )
+ refined_irreg = _refine(
+ fn1, exc, refinement_irreg.filter[-1],
+ refinement_irreg.propagator_sqrt[-1], **kwargs
+ )
+ assert_allclose(refined_irreg, refined, rtol=1e-14, atol=1e-13)
+
+
+if __name__ == "__main__":
+ test_refinement_matrices_1d(5.)
+ test_refinement_1d(42, 10.)
diff --git a/test/test_re/test_refine_util.py b/test/test_re/test_refine_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..36df85c0d076b512c1c0d541bf3ae8a3e8b94bd6
--- /dev/null
+++ b/test/test_re/test_refine_util.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
+
+from functools import partial
+
+import jax
+import numpy as np
+import pytest
+
+from nifty8.re import refine_chart, refine_util
+
+pmp = pytest.mark.parametrize
+
+
+@pmp("shape0", ((16, ), (13, 15), (11, 12, 13)))
+@pmp("depth", (1, 2))
+@pmp("_coarse_size", (3, 5, 7))
+@pmp("_fine_size", (2, 4, 6))
+@pmp("_fine_strategy", ("jump", "extend"))
+def test_shape_translations(
+ shape0, depth, _coarse_size, _fine_size, _fine_strategy
+):
+ kwargs = {
+ "_coarse_size": _coarse_size,
+ "_fine_size": _fine_size,
+ "_fine_strategy": _fine_strategy
+ }
+
+ def cf(shape0, xi):
+ chart = refine_chart.CoordinateChart(
+ shape0=shape0,
+ depth=depth,
+ distances0=(1., ) * len(shape0),
+ **kwargs
+ )
+ return refine_chart.RefinementField.apply(
+ xi, chart=chart, kernel=lambda x: x
+ )
+
+ dom = refine_util.get_refinement_shapewithdtype(shape0, depth, **kwargs)
+ tgt = jax.eval_shape(partial(cf, shape0), dom)
+ tgt_pred_shp = refine_util.coarse2fine_shape(shape0, depth, **kwargs)
+ assert tgt_pred_shp == tgt.shape
+ assert dom[-1].size == tgt.size == np.prod(tgt_pred_shp)
+
+ shape0_pred = refine_util.fine2coarse_shape(tgt.shape, depth, **kwargs)
+ dom_pred = refine_util.get_refinement_shapewithdtype(
+ shape0_pred, depth, **kwargs
+ )
+ tgt_pred = jax.eval_shape(partial(cf, shape0_pred), dom_pred)
+
+ assert tgt.shape == tgt_pred.shape
+ if _fine_strategy == "jump":
+ assert shape0_pred == shape0
+ else:
+ assert _fine_strategy == "extend"
+ assert all(s0_p <= s0 for s0_p, s0 in zip(shape0_pred, shape0))
+
+
+@pmp("seed", (42, 45))
+def test_gauss_kl(seed, n_resamples=100):
+ rng = np.random.default_rng(seed)
+ # Closed form: KL(N(0, A) || N(0, B)) = (tr(B^{-1} A) - d + ln det(B)
+ # - ln det(A)) / 2; the scaled-covariance expectations below follow from it
+ for _ in range(n_resamples):
+ d = max(rng.poisson(4), 1)
+ m_t = rng.normal(size=(d, d))
+ m_t = m_t @ m_t.T
+ scl = rng.lognormal(2., 3.)
+
+ np.testing.assert_allclose(
+ refine_util.gauss_kl(m_t, m_t), 0., atol=1e-11
+ )
+ kl_rhs_scl = 0.5 * d * (np.log(scl) + 1. / scl - 1.)
+ np.testing.assert_allclose(
+ kl_rhs_scl, refine_util.gauss_kl(m_t, scl * m_t), rtol=1e-11
+ )
+ kl_lhs_scl = 0.5 * d * (-np.log(scl) + scl - 1.)
+ np.testing.assert_allclose(
+ kl_lhs_scl, refine_util.gauss_kl(scl * m_t, m_t), rtol=1e-10
+ )
diff --git a/test/test_re/test_stats_distributions.py b/test/test_re/test_stats_distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9e021f84d5a903f6c23c59a14f8ad7edb6de41
--- /dev/null
+++ b/test/test_re/test_stats_distributions.py
@@ -0,0 +1,41 @@
+import numpy as np
+from numpy.testing import assert_allclose
+import pytest
+
+import nifty8.re as jft
+
+pmp = pytest.mark.parametrize
+
+
+@pmp("a", (3., 1.5, 4.))
+@pmp("scale", (2., 4.))
+@pmp("loc", (2., 4., 0.))
+@pmp("seed", (42, 43))
+def test_invgamma_roundtrip(a, scale, loc, seed, step=1e-1):
+ rng = np.random.default_rng(seed)
+
+ n_samples = int(1e+4)
+ n_rvs = rng.normal(loc=0., scale=2., size=(n_samples, ))
+ n_rvs = n_rvs.clip(-5.2, 5.2)
+
+ pr = jft.invgamma_prior(a, scale, loc=loc, step=step)
+ ipr = jft.invgamma_invprior(a, scale, loc=loc, step=step)
+
+ n_roundtrip = ipr(pr(n_rvs))
+ assert_allclose(n_roundtrip, n_rvs, rtol=1e-4, atol=1e-3)
+
+
+@pmp("mean", (2., 4.))
+@pmp("std", (2., 4.))
+@pmp("seed", (42, 43))
+def test_lognormal_roundtrip(mean, std, seed):
+ rng = np.random.default_rng(seed)
+
+ n_samples = int(1e+4)
+ n_rvs = rng.normal(loc=0., scale=2., size=(n_samples, ))
+
+ pr = jft.lognormal_prior(mean, std)
+ ipr = jft.lognormal_invprior(mean, std)
+
+ n_roundtrip = ipr(pr(n_rvs))
+ assert_allclose(n_roundtrip, n_rvs, rtol=1e-6, atol=1e-6)
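+
+
+# A minimal moment-matching sketch complementing the roundtrip tests above.
+# It assumes `lognormal_prior(mean, std)` maps a standard normal variate to a
+# lognormal one whose first two moments match `mean` and `std`; the test name
+# and tolerances are ad hoc, not part of the library's test suite
+@pmp("seed", (42, ))
+def test_lognormal_moments_sketch(seed, mean=2., std=2.):
+ rng = np.random.default_rng(seed)
+
+ n_samples = int(1e+6)
+ xi = rng.normal(loc=0., scale=1., size=(n_samples, ))
+ samples = jft.lognormal_prior(mean, std)(xi)
+
+ # Monte Carlo error shrinks like 1/sqrt(n_samples); keep tolerances loose
+ assert_allclose(np.mean(samples), mean, rtol=5e-2)
+ assert_allclose(np.std(samples), std, rtol=5e-2)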