Commit 1ca27aa8 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add type tests and docs

parent 2f1cb237
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .energy import Energy
from ..linearization import Linearization
from .. import utilities from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
from .energy import Energy
class MetricGaussianKL(Energy): class MetricGaussianKL(Energy):
...@@ -53,6 +54,8 @@ class MetricGaussianKL(Energy): ...@@ -53,6 +54,8 @@ class MetricGaussianKL(Energy):
as they are equaly legitimate samples. If true, the number of used as they are equaly legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. (default : False) extreme sample variation is counterbalanced. (default : False)
_samples : None
Only a parameter for internal uses. Typically not to be set by users.
Notes Notes
----- -----
...@@ -61,14 +64,23 @@ class MetricGaussianKL(Energy): ...@@ -61,14 +64,23 @@ class MetricGaussianKL(Energy):
""" """
def __init__(self, mean, hamiltonian, n_samples, constants=[], def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=None, mirror_samples=False, point_estimates=[], mirror_samples=False,
_samples=None): _samples=None):
super(MetricGaussianKL, self).__init__(mean) super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not mean.domain: if hamiltonian.domain is not mean.domain:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._constants = list(constants)
self._point_estimates = list(point_estimates)
if not isinstance(mirror_samples, bool):
raise TypeError raise TypeError
self._hamiltonian = hamiltonian self._hamiltonian = hamiltonian
self._constants = constants
self._constants_samples = point_estimates
if _samples is None: if _samples is None:
met = hamiltonian(Linearization.make_partial_var( met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric mean, point_estimates, True)).metric
...@@ -94,7 +106,7 @@ class MetricGaussianKL(Energy): ...@@ -94,7 +106,7 @@ class MetricGaussianKL(Energy):
def at(self, position): def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0, return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._constants_samples, self._constants, self._point_estimates,
_samples=self._samples) _samples=self._samples)
@property @property
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment