Skip to content
Snippets Groups Projects
Commit 174f63eb authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

re.kl.py: Fix tables in docstrings

parent 3525c88d
No related branches found
No related tags found
1 merge request!780Fix docstring formatting
Pipeline #132848 passed
...@@ -32,6 +32,7 @@ def cond_raise(condition, exception): ...@@ -32,6 +32,7 @@ def cond_raise(condition, exception):
call(maybe_raise, condition, result_shape=None) call(maybe_raise, condition, result_shape=None)
# TODO (?): optionally accept ham.metric and likelihood.lsm
def _sample_standard_hamiltonian( def _sample_standard_hamiltonian(
hamiltonian: StandardHamiltonian, hamiltonian: StandardHamiltonian,
primals, primals,
...@@ -144,6 +145,7 @@ def sample_standard_hamiltonian( ...@@ -144,6 +145,7 @@ def sample_standard_hamiltonian(
return inv_met_smpl return inv_met_smpl
# TODO (?): optionally accept ham.metric and likelihood.lsm and likelihood.transformation
def geometrically_sample_standard_hamiltonian( def geometrically_sample_standard_hamiltonian(
hamiltonian: StandardHamiltonian, hamiltonian: StandardHamiltonian,
primals, primals,
...@@ -400,7 +402,6 @@ def MetricKL( ...@@ -400,7 +402,6 @@ def MetricKL(
Parameters Parameters
---------- ----------
hamiltonian : :class:`nifty8.src.re.likelihood.StandardHamiltonian` hamiltonian : :class:`nifty8.src.re.likelihood.StandardHamiltonian`
Hamiltonian of the approximated probability distribution. Hamiltonian of the approximated probability distribution.
primals : :class:`nifty8.re.field.Field` primals : :class:`nifty8.re.field.Field`
...@@ -409,7 +410,7 @@ def MetricKL( ...@@ -409,7 +410,7 @@ def MetricKL(
Number of samples used to stochastically estimate the KL. Number of samples used to stochastically estimate the KL.
key : DeviceArray key : DeviceArray
A PRNG-key. A PRNG-key.
mirror_samples : boolean mirror_samples : bool
Whether the mirrored version of the drawn samples are also used. Whether the mirrored version of the drawn samples are also used.
If true, the number of used samples doubles. If true, the number of used samples doubles.
Mirroring samples stabilizes the KL estimate as extreme Mirroring samples stabilizes the KL estimate as extreme
...@@ -420,10 +421,8 @@ def MetricKL( ...@@ -420,10 +421,8 @@ def MetricKL(
itself. The function is used to map the drawing of samples. Possible itself. The function is used to map the drawing of samples. Possible
string-keys are: string-keys are:
keys - functions - 'pmap' or 'p' for `jax.pmap`
------------------------------------- - 'lax.map' or 'lax' for `jax.lax.map`
'pmap' or 'p' - jax.pmap
'lax.map' or 'lax' - jax.lax.map
In case sample_mapping is passed as a function, it should produce a In case sample_mapping is passed as a function, it should produce a
mapped function f_mapped of a general function f as: `f_mapped = mapped function f_mapped of a general function f as: `f_mapped =
...@@ -557,7 +556,6 @@ def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs): ...@@ -557,7 +556,6 @@ def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs):
Parameters Parameters
---------- ----------
ham : :class:`nifty8.src.re.likelihood.StandardHamiltonian` ham : :class:`nifty8.src.re.likelihood.StandardHamiltonian`
Hamiltonian of the approximated probability distribution, Hamiltonian of the approximated probability distribution,
of which the mean value and the mean gradient are to be computed. of which the mean value and the mean gradient are to be computed.
...@@ -566,11 +564,9 @@ def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs): ...@@ -566,11 +564,9 @@ def mean_value_and_grad(ham: Callable, sample_mapping='vmap', *args, **kwargs):
itself. The function is used to map the drawing of samples. Possible itself. The function is used to map the drawing of samples. Possible
string-keys are: string-keys are:
keys - functions - 'vmap' or 'v' for `jax.vmap`
------------------------------------- - 'pmap' or 'p' for `jax.pmap`
'vmap' or 'v' - jax.vmap - 'lax.map' or 'lax' for `jax.lax.map`
'pmap' or 'p' - jax.pmap
'lax.map' or 'lax' - jax.lax.map
In case sample_mapping is passed as a function, it should produce a In case sample_mapping is passed as a function, it should produce a
mapped function f_mapped of a general function f as: `f_mapped = mapped function f_mapped of a general function f as: `f_mapped =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment