Commit 1ca27aa8 authored by Philipp Arras's avatar Philipp Arras

Add type tests and docs

parent 2f1cb237
......@@ -15,9 +15,10 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .energy import Energy
from ..linearization import Linearization
from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
from .energy import Energy
class MetricGaussianKL(Energy):
......@@ -53,6 +54,8 @@ class MetricGaussianKL(Energy):
as they are equaly legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. (default : False)
_samples : None
Only a parameter for internal uses. Typically not to be set by users.
Notes
-----
......@@ -61,14 +64,23 @@ class MetricGaussianKL(Energy):
"""
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=None, mirror_samples=False,
point_estimates=[], mirror_samples=False,
_samples=None):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not mean.domain:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._constants = list(constants)
self._point_estimates = list(point_estimates)
if not isinstance(mirror_samples, bool):
raise TypeError
self._hamiltonian = hamiltonian
self._constants = constants
self._constants_samples = point_estimates
if _samples is None:
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
......@@ -94,7 +106,7 @@ class MetricGaussianKL(Energy):
def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._constants_samples,
self._constants, self._point_estimates,
_samples=self._samples)
@property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment