Skip to content
Snippets Groups Projects
Commit 8a8e6751 authored by Philipp Arras's avatar Philipp Arras
Browse files

Gaussian energy takes inverse covariance as input

This is a performance tweak. Since inverse diagonal operators are automatically
converted to a pure diagonal operator with the inverse on its diagonal, this
commit saves memory.
parent 369354fc
No related branches found
No related tags found
1 merge request!327Power grid
Pipeline #50164 passed
......@@ -109,7 +109,8 @@ if __name__ == '__main__':
minimizer = ift.NewtonCG(ic_newton)
# Set up likelihood and information Hamiltonian
likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response)
likelihood = ift.GaussianEnergy(mean=data,
inverse_covariance=N.inverse)(signal_response)
H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_mean = ift.MultiField.full(H.domain, 0.)
......
......@@ -110,8 +110,8 @@ class GaussianEnergy(EnergyOperator):
----------
mean : Field
Mean of the Gaussian. Default is 0.
covariance : LinearOperator
Covariance of the Gaussian. Default is the identity operator.
inverse_covariance : LinearOperator
Inverse covariance of the Gaussian. Default is the identity operator.
domain : Domain, DomainTuple, tuple of Domain or MultiDomain
Operator domain. By default it is inferred from `mean` or
`covariance` if specified
......@@ -121,28 +121,27 @@ class GaussianEnergy(EnergyOperator):
At least one of the arguments has to be provided.
"""
def __init__(self, mean=None, covariance=None, domain=None):
def __init__(self, mean=None, inverse_covariance=None, domain=None):
if mean is not None and not isinstance(mean, (Field, MultiField)):
raise TypeError
if covariance is not None and not isinstance(covariance,
LinearOperator):
if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
raise TypeError
self._domain = None
if mean is not None:
self._checkEquivalence(mean.domain)
if covariance is not None:
self._checkEquivalence(covariance.domain)
if inverse_covariance is not None:
self._checkEquivalence(inverse_covariance.domain)
if domain is not None:
self._checkEquivalence(domain)
if self._domain is None:
raise ValueError("no domain given")
self._mean = mean
if covariance is None:
if inverse_covariance is None:
self._op = SquaredNormOperator(self._domain).scale(0.5)
else:
self._op = QuadraticFormOperator(covariance.inverse)
self._icov = None if covariance is None else covariance.inverse
self._op = QuadraticFormOperator(inverse_covariance)
self._icov = None if inverse_covariance is None else inverse_covariance
def _checkEquivalence(self, newdom):
newdom = makeDomain(newdom)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment