Skip to content
Snippets Groups Projects
Commit b4a386be authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

Use metric sample for probing

Approximates tr(log(M)) using tr(log(T)) where T is the projection of M
into the krylov subspace K(M, v) where v is a sample from the metric M
(i.E. v = v_lh + v_pr where v_lh/v_pr are samples from the
likelihood/prior metric, respectively. In addition, the projected sample
is constructed by taking v_pr projecting out the subspace K(M,v) using
its eigen-basis. This ensures that both, the prior dominated part of v
and the part already covered by tr(log(T)) is projected out.

Original author: Philipp Frank
parent 58e0ca09
No related branches found
No related tags found
1 merge request!832Better Lanczos interface
......@@ -19,27 +19,42 @@ jax.config.update("jax_enable_x64", True)
# %%
def stochastic_lq_logdet(
def lanczos_logdet(
mat,
v,
order: int,
key,
*,
shape0=None,
dtype=None,
):
"""Computes a stochastic estimate of the log-determinate of a matrix using
the stochastic Lanczos quadrature algorithm.
"""Computes a stochastic estimate of the log-determinate of the Lanczos
decomposed matrix. This is not the same as applying the stochastic Lanczos
quadrature algorithm as it estimates the log-determinate for the
decomposition only.
"""
shape0 = shape0 if shape0 is not None else mat.shape[0]
mat = mat.__matmul__ if not hasattr(mat, "__call__") else mat
key = random.PRNGKey(key) if not isinstance(key, jnp.ndarray) else key
probe = random.normal(key, (shape0, ), dtype=dtype)
tridiags, vecs = jft.lanczos.lanczos_tridiag(mat, probe, order=order)
logdet = jft.lanczos.stochastic_logdet_from_lanczos(
tridiags.reshape((1, ) + tridiags.shape), shape0
tridiag, vecs = jft.lanczos.lanczos_tridiag(mat, v, order=order)
eig_vals = jnp.linalg.eigvalsh(tridiag)
return jnp.log(eig_vals).sum(), vecs
def _metric_sample(
hamiltonian: jft.StandardHamiltonian,
primals,
key,
):
if not isinstance(hamiltonian, jft.StandardHamiltonian):
te = f"`hamiltonian` of invalid type; got '{type(hamiltonian)}'"
raise TypeError(te)
subkey_nll, subkey_prr = random.split(key, 2)
nll_smpl = jft.kl.sample_likelihood(
hamiltonian.likelihood, primals, key=subkey_nll
)
return logdet, vecs
prr_inv_metric_smpl = jft.random_like(key=subkey_prr, primals=primals)
# One may transform any metric sample to a sample of the inverse
# metric by simply applying the inverse metric to it
prr_smpl = prr_inv_metric_smpl
met_smpl = nll_smpl + prr_smpl
return met_smpl, prr_smpl
def geomap(
......@@ -60,17 +75,18 @@ def geomap(
)
return o
key_lcz, key_smpls = random.split(key, 2)
logdet, vecs = stochastic_lq_logdet(
mat, order, key_lcz, shape0=p.size, dtype=p.dtype
)
probe, smpl = _metric_sample(hamiltonian, pos, key)
probe = flatten_util.ravel_pytree(probe)[0]
smpl = flatten_util.ravel_pytree(smpl)[0]
logdet, vecs = lanczos_logdet(mat, probe, order, shape0=p.size)
if sample_orthonormally is None:
if not sample_orthonormally:
energy = hamiltonian(pos)
smpl_orig, smpl = None, None
else:
smpl = random.normal(key_smpls, p.shape, dtype=p.dtype)
smpl_orig = smpl.copy()
#smpl = random.normal(smpl_key, p.shape)
smpl_orig = unflatten(smpl.copy())
# TODO: Pull into new lanczos method which computes orthoganlized smpls
# for vecs
ortho_smpl = vecs @ smpl
......@@ -155,9 +171,9 @@ pos = 1e-2 * pos_init.copy()
print("!!! HAM", ham(pos))
print("!!! metric", ham.metric(pos, pos) @ pos)
# This is 50 times slower in compile time than ham.metric
geomap_order = 5
geomap_order = 40
geomap_energy = geomap(
ham, geomap_order, subkey_geomap, sample_orthonormally=False
ham, geomap_order, subkey_geomap, sample_orthonormally=True
)
geomap_energy = jax.jit(geomap_energy, static_argnames=("return_aux", ))
......@@ -191,7 +207,7 @@ plt.show()
# %%
smpls_by_order = []
for i in range(1, geomap_order):
_, (_, s) = geomap(ham, i, subkey_geomap, sample_orthonormally=False)(
_, (_, s) = geomap(ham, i, subkey_geomap, sample_orthonormally=True)(
opt_state_geomap.x, return_aux=True
)
smpls_by_order += [s]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment