Skip to content
Snippets Groups Projects
Commit 174f63eb authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

re.kl.py: Fix tables in docstrings

parent 3525c88d
No related branches found
No related tags found
1 merge request!780Fix docstring formatting
Pipeline #132848 passed
......@@ -32,6 +32,7 @@ def cond_raise(condition, exception):
call(maybe_raise, condition, result_shape=None)
# TODO (?): optionally accept ham.metric and likelihood.lsm
def _sample_standard_hamiltonian(
hamiltonian: StandardHamiltonian,
primals,
......@@ -144,6 +145,7 @@ def sample_standard_hamiltonian(
return inv_met_smpl
# TODO (?): optionally accept ham.metric and likelihood.lsm and likelihood.transformation
def geometrically_sample_standard_hamiltonian(
hamiltonian: StandardHamiltonian,
primals,
......@@ -400,7 +402,6 @@ def MetricKL(
Parameters
----------
hamiltonian : :class:`nifty8.src.re.likelihood.StandardHamiltonian`
Hamiltonian of the approximated probability distribution.
primals : :class:`nifty8.re.field.Field`
......@@ -409,7 +410,7 @@ def MetricKL(
Number of samples used to stochastically estimate the KL.
key : DeviceArray
A PRNG-key.
mirror_samples : boolean
mirror_samples : bool
Whether the mirrored version of the drawn samples are also used.
If true, the number of used samples doubles.
Mirroring samples stabilizes the KL estimate as extreme
......@@ -420,10 +421,8 @@ def MetricKL(
itself. The function is used to map the drawing of samples. Possible
string-keys are:
keys - functions
-------------------------------------
'pmap' or 'p' - jax.pmap
'lax.map' or 'lax' - jax.lax.map
- 'pmap' or 'p' for `jax.pmap`
- 'lax.map' or 'lax' for `jax.lax.map`
In case sample_mapping is passed as a function, it should produce a
mapped function f_mapped of a general function f as: `f_mapped =
......@@ -557,7 +556,6 @@ def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs):
Parameters
----------
ham : :class:`nifty8.src.re.likelihood.StandardHamiltonian`
Hamiltonian of the approximated probability distribution,
of which the mean value and the mean gradient are to be computed.
......@@ -566,11 +564,9 @@ def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs):
itself. The function is used to map the drawing of samples. Possible
string-keys are:
keys - functions
-------------------------------------
'vmap' or 'v' - jax.vmap
'pmap' or 'p' - jax.pmap
'lax.map' or 'lax' - jax.lax.map
- 'vmap' or 'v' for `jax.vmap`
- 'pmap' or 'p' for `jax.pmap`
- 'lax.map' or 'lax' for `jax.lax.map`
In case sample_mapping is passed as a function, it should produce a
mapped function f_mapped of a general function f as: `f_mapped =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment