Commit 92f59bae authored by Gordian Edenhofer

Merge branch 'better_lanczos' into 'NIFTy_8'

Better Lanczos interface

See merge request !832
parents 1860917d cd349c61
Pipeline #152880 passed
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
from functools import partial
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, TypeVar, Union
import jax
from jax import numpy as jnp
@@ -9,27 +8,30 @@ from jax import random
from .forest_util import ShapeWithDtype
+from .disable_jax_control_flow import fori_loop
-def lanczos_tridiag(
-    mat: Callable, shape_dtype_struct: ShapeWithDtype, order: int,
-    key: jnp.ndarray
-):
+V = TypeVar("V")
+def lanczos_tridiag(mat: Callable[[V], V], v: V, order: int):
"""Compute the Lanczos decomposition into a tri-diagonal matrix and its
corresponding orthonormal projection matrix.
"""
-    tridiag = jnp.zeros((order, order), dtype=shape_dtype_struct.dtype)
-    vecs = jnp.zeros(
-        (order, ) + shape_dtype_struct.shape, dtype=shape_dtype_struct.dtype
-    )
+    swd = ShapeWithDtype.from_leave(v)
+    tridiag = jnp.zeros((order, order), dtype=swd.dtype)
+    vecs = jnp.zeros((order, ) + swd.shape, dtype=swd.dtype)
-    v = random.normal(key, shape=shape_dtype_struct.shape)
v = v / jnp.linalg.norm(v)
vecs = vecs.at[0].set(v)
+    # TODO
+    # * use `forest_util.dot` and `forest_util.norm` in favor of plain `jnp.dot`
+    # * remove all reshapes as they are unnecessary
# Zeroth iteration
w = mat(v)
-    if w.shape != shape_dtype_struct.shape:
-        ve = f"shape of `mat(v)` {w.shape!r} incompatible with {shape_dtype_struct}"
+    if w.shape != swd.shape:
+        ve = f"shape of `mat(v)` {w.shape!r} incompatible with {swd}"
raise ValueError(ve)
alpha = jnp.dot(w, v)
tridiag = tridiag.at[(0, 0)].set(alpha)
@@ -43,7 +45,7 @@ def lanczos_tridiag(
def reortho_step(j, state):
vecs, w = state
-        tau = vecs[j, :].reshape(shape_dtype_struct.shape)
+        tau = vecs[j, :].reshape(swd.shape)
coeff = jnp.dot(w, tau)
w -= coeff * tau
return vecs, w
@@ -51,8 +53,10 @@ def lanczos_tridiag(
def lanczos_step(i, state):
tridiag, vecs, beta = state
-        v = vecs[i, :].reshape(shape_dtype_struct.shape)
-        v_old = vecs[i - 1, :].reshape(shape_dtype_struct.shape)
+        # TODO: only save current and last vector and do not
+        # reorthogonalize??????; check theory beforehand!!!
+        v = vecs[i, :].reshape(swd.shape)
+        v_old = vecs[i - 1, :].reshape(swd.shape)
w = mat(v) - beta * v_old
alpha = jnp.dot(w, v)
@@ -60,7 +64,9 @@ def lanczos_tridiag(
w -= alpha * v
# Full reorthogonalization
-        vecs, w = jax.lax.fori_loop(0, i, reortho_step, (vecs, w))
+        # NOTE, in theory the loop could terminate at `i` but this would make
+        # JAX's default backwards pass not work
+        vecs, w = fori_loop(0, order, reortho_step, (vecs, w))
# TODO: Raise if lanczos vectors are independent i.e. `beta` small?
beta = jnp.linalg.norm(w)
@@ -71,18 +77,18 @@ def lanczos_tridiag(
return tridiag, vecs, beta
-    tridiag, vecs, beta = jax.lax.fori_loop(
+    tridiag, vecs, beta = fori_loop(
1, order - 1, lanczos_step, (tridiag, vecs, beta)
)
# Final tridiag value and reorthogonalization
-    v = vecs[order - 1, :].reshape(shape_dtype_struct.shape)
-    v_old = vecs[order - 2, :].reshape(shape_dtype_struct.shape)
+    v = vecs[order - 1, :].reshape(swd.shape)
+    v_old = vecs[order - 2, :].reshape(swd.shape)
w = mat(v) - beta * v_old
alpha = jnp.dot(w, v)
tridiag = tridiag.at[(order - 1, order - 1)].set(alpha)
w -= alpha * v
-    vecs, w = jax.lax.fori_loop(0, order - 1, reortho_step, (vecs, w))
+    vecs, w = fori_loop(0, order - 1, reortho_step, (vecs, w))
return (tridiag, vecs)
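
For reference, a minimal usage sketch of the new call signature, mirroring the updated test at the end of this merge request: the probe vector is now passed in explicitly instead of a ShapeWithDtype plus a PRNG key, so the caller controls its shape, dtype and distribution. The `nifty8.re` import path below is an assumption; the diff itself only shows the alias `jft`.

from jax import random
import nifty8.re as jft  # assumed package path; only the alias `jft` appears in the diff

key = random.PRNGKey(42)
m = random.normal(key, (8, 8))
m = m @ m.T  # symmetric, (semi-)positive-definite toy matrix
v = random.rademacher(random.PRNGKey(0), (8, ), float)  # explicit probe vector
tridiag, vecs = jft.lanczos.lanczos_tridiag(lambda x: m @ x, v, 8)
# at full order, `vecs.T @ tridiag @ vecs` reconstructs `m` up to round-off
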
@@ -114,7 +120,8 @@ def stochastic_lq_logdet(
key: Union[int, jnp.ndarray],
*,
shape0: Optional[int] = None,
-    dtype=None
+    dtype=None,
+    cmap=jax.vmap,
):
"""Computes a stochastic estimate of the log-determinate of a matrix using
the stochastic Lanczos quadrature algorithm.
@@ -123,9 +130,12 @@ def stochastic_lq_logdet(
mat = mat.__matmul__ if not hasattr(mat, "__call__") else mat
if not isinstance(key, jnp.ndarray):
key = random.PRNGKey(key)
-    keys = random.split(key, n_samples)
+    key_smpls = random.split(key, n_samples)
+    def random_lanczos(k):
+        v = random.rademacher(k, (shape0, ), dtype=dtype)
+        tri, _ = lanczos_tridiag(mat, v, order=order)
+        return tri
-    lanczos = partial(lanczos_tridiag, mat, ShapeWithDtype(shape0, dtype))
-    tridiags, _ = jax.vmap(lanczos, in_axes=(None, 0),
-                           out_axes=(0, 0))(order, keys)
+    tridiags = cmap(random_lanczos)(key_smpls)
return stochastic_logdet_from_lanczos(tridiags, shape0)
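
The new `cmap` keyword controls how the per-sample Lanczos runs are mapped over the split keys; it defaults to `jax.vmap` and only needs the transform-style calling convention `cmap(fun)(keys)` used above. Below is a sketch of swapping in a sequential map to trade speed for peak memory; the leading positional parameters `(mat, order, n_samples, key)` are not visible in this hunk and are assumed here, as is importing `stochastic_lq_logdet` from this module.

import jax
from jax import numpy as jnp

mat = 2.0 * jnp.eye(100)  # toy matrix with known log-determinant 100 * log(2)

# default: all probe samples are evaluated in one batched pass via jax.vmap
logdet_batched = stochastic_lq_logdet(mat, 15, 10, key=42, shape0=100, dtype=float)

# sequential alternative; `cmap` only has to turn `random_lanczos` into a
# function that maps it over the leading axis of the sample keys
logdet_sequential = stochastic_lq_logdet(
    mat, 15, 10, key=42, shape0=100, dtype=float,
    cmap=lambda f: (lambda ks: jax.lax.map(f, ks)),
)
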
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
import pytest
pytest.importorskip("jax")
from functools import partial
@@ -43,9 +44,8 @@ def test_lanczos_tridiag(seed, shape0):
m = rng.normal(size=(shape0, ) * 2)
m = m @ m.T # ensure positive-definiteness
-    tridiag, vecs = jft.lanczos.lanczos_tridiag(
-        partial(matmul, m), jft.ShapeWithDtype((shape0, )), shape0, rng_key
-    )
+    v = random.rademacher(rng_key, (shape0, ), float)
+    tridiag, vecs = jft.lanczos.lanczos_tridiag(partial(matmul, m), v, shape0)
m_est = vecs.T @ tridiag @ vecs
assert_allclose(m_est, m, atol=1e-13, rtol=1e-13)
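
Note on the tight tolerances: with `order == shape0` the Lanczos vectors in `vecs` form a complete orthonormal basis, so `tridiag == vecs @ m @ vecs.T` holds in exact arithmetic and `vecs.T @ tridiag @ vecs` recovers `m` up to floating-point round-off; for `order < shape0` the same product would only be a low-rank approximation of `m`.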