Skip to content
Snippets Groups Projects
Commit c58224d8 authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

stats_distributions: Return trafos suitable for within jit

parent 5c40e9ee
Branches
Tags
1 merge request!904Re likelihood
......@@ -2,7 +2,7 @@ from functools import partial
from typing import Callable, Optional
from jax import numpy as jnp
from jax.tree_util import tree_map
from jax.tree_util import Partial, tree_map
from ..tree_math.vector_math import any as tree_any
......@@ -12,7 +12,17 @@ log = partial(tree_map, jnp.log)
log1p = partial(tree_map, jnp.log1p)
def laplace_prior(alpha) -> Callable:
def _standard_to_laplace(xi, *, alpha):
from jax.scipy.stats import norm
norm_logcdf = partial(tree_map, norm.logcdf)
res = (xi < 0) * (norm_logcdf(xi) + jnp.log(2))
res -= (xi > 0) * (norm_logcdf(-xi) + jnp.log(2))
return res * alpha
def laplace_prior(alpha) -> Partial:
"""
Takes random normal samples and outputs samples distributed according to
......@@ -20,34 +30,28 @@ def laplace_prior(alpha) -> Callable:
P(x|a) = exp(-|x|/a)/a/2
"""
from jax.scipy.stats import norm
norm_logcdf = partial(tree_map, norm.logcdf)
return Partial(_standard_to_laplace, alpha=alpha)
def standard_to_laplace(xi):
res = (xi < 0) * (norm_logcdf(xi) + jnp.log(2))
res -= (xi > 0) * (norm_logcdf(-xi) + jnp.log(2))
return res * alpha
return standard_to_laplace
def _standard_to_normal(xi, *, mean, std):
return mean + std * xi
def normal_prior(mean, std) -> Callable:
def normal_prior(mean, std) -> Partial:
"""Match standard normally distributed random variables to non-standard
variables.
"""
def standard_to_normal(xi):
return mean + std * xi
return Partial(_standard_to_normal, mean=mean, std=std)
return standard_to_normal
def _normal_to_standard(y, *, mean, std):
return (y - mean) / std
def normal_invprior(mean, std) -> Callable:
"""Get the inverse transform to `normal_prior`."""
def normal_to_standard(y):
return (y - mean) / std
return normal_to_standard
def normal_invprior(mean, std) -> Partial:
"""Get the inverse transform to `normal_prior`."""
return Partial(_normal_to_standard, mean=mean, std=std)
def lognormal_moments(mean, std):
......@@ -64,7 +68,11 @@ def lognormal_moments(mean, std):
return logmean, logstd
def lognormal_prior(mean, std, *, _log_mean=None, _log_std=None) -> Callable:
def _standard_to_lognormal(xi, *, log_mean, log_std):
return exp(_standard_to_normal(xi, mean=log_mean, std=log_std))
def lognormal_prior(mean, std, *, _log_mean=None, _log_std=None) -> Partial:
"""Moment-match standard normally distributed random variables to log-space
Takes random normal samples and outputs samples distributed according to
......@@ -75,31 +83,29 @@ def lognormal_prior(mean, std, *, _log_mean=None, _log_std=None) -> Callable:
such that the mean and standard deviation of the distribution matches the
specified values.
"""
if _log_mean is not None and _log_std is not None:
standard_to_normal = normal_prior(_log_mean, _log_std)
else:
standard_to_normal = normal_prior(*lognormal_moments(mean, std))
if _log_mean is None and _log_std is None:
_log_mean, _log_std = lognormal_moments(mean, std)
return Partial(_standard_to_lognormal, log_mean=_log_mean, log_std=_log_std)
def standard_to_lognormal(xi):
return exp(standard_to_normal(xi))
return standard_to_lognormal
def _lognormal_to_standard(y, *, log_mean, log_std):
return _normal_to_standard(log(y), mean=log_mean, std=log_std)
def lognormal_invprior(mean, std, *, _log_mean=None, _log_std=None) -> Callable:
def lognormal_invprior(mean, std, *, _log_mean=None, _log_std=None) -> Partial:
"""Get the inverse transform to `lognormal_prior`."""
if _log_mean is not None and _log_std is not None:
ln_m, ln_std = _log_mean, _log_std
else:
ln_m, ln_std = lognormal_moments(mean, std)
if _log_mean is None and _log_std is None:
_log_mean, _log_std = lognormal_moments(mean, std)
return Partial(_lognormal_to_standard, log_mean=_log_mean, log_std=_log_std)
def lognormal_to_standard(y):
return (log(y) - ln_m) / ln_std
def _standard_to_uniform(xi, *, a_min, scale):
from jax.scipy.stats import norm
return lognormal_to_standard
return a_min + scale * tree_map(norm.cdf, xi)
def uniform_prior(a_min=0., a_max=1.) -> Callable:
def uniform_prior(a_min=0., a_max=1.) -> Partial:
"""Transform a standard normal into a uniform distribution.
Parameters
......@@ -111,18 +117,13 @@ def uniform_prior(a_min=0., a_max=1.) -> Callable:
"""
from jax.scipy.stats import norm
norm_cdf = partial(tree_map, norm.cdf)
scale = a_max - a_min
if isinstance(a_min, float) and isinstance(
a_max, float
) and a_min == 0. and a_max == 1.:
return norm_cdf
return Partial(partial(tree_map, norm.cdf))
def standard_to_uniform(xi):
return a_min + scale * norm_cdf(xi)
return standard_to_uniform
scale = a_max - a_min
return Partial(_standard_to_uniform, a_min=a_min, scale=scale)
def interpolator(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment