Commit 8a8e6751 authored by Philipp Arras's avatar Philipp Arras
Browse files

Gaussian energy takes inverse covariance as input

This is a performance tweak. Since inverse diagonal operators are automatically
converted to a pure diagonal operator with the inverse on its diagonal, this
commit saves memory.
parent 369354fc
Pipeline #50164 passed with stages
in 7 minutes and 53 seconds
...@@ -109,7 +109,8 @@ if __name__ == '__main__': ...@@ -109,7 +109,8 @@ if __name__ == '__main__':
minimizer = ift.NewtonCG(ic_newton) minimizer = ift.NewtonCG(ic_newton)
# Set up likelihood and information Hamiltonian # Set up likelihood and information Hamiltonian
likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response) likelihood = ift.GaussianEnergy(mean=data,
inverse_covariance=N.inverse)(signal_response)
H = ift.StandardHamiltonian(likelihood, ic_sampling) H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_mean = ift.MultiField.full(H.domain, 0.) initial_mean = ift.MultiField.full(H.domain, 0.)
......
...@@ -110,8 +110,8 @@ class GaussianEnergy(EnergyOperator): ...@@ -110,8 +110,8 @@ class GaussianEnergy(EnergyOperator):
---------- ----------
mean : Field mean : Field
Mean of the Gaussian. Default is 0. Mean of the Gaussian. Default is 0.
covariance : LinearOperator inverse_covariance : LinearOperator
Covariance of the Gaussian. Default is the identity operator. Inverse covariance of the Gaussian. Default is the identity operator.
domain : Domain, DomainTuple, tuple of Domain or MultiDomain domain : Domain, DomainTuple, tuple of Domain or MultiDomain
Operator domain. By default it is inferred from `mean` or Operator domain. By default it is inferred from `mean` or
`covariance` if specified `covariance` if specified
...@@ -121,28 +121,27 @@ class GaussianEnergy(EnergyOperator): ...@@ -121,28 +121,27 @@ class GaussianEnergy(EnergyOperator):
At least one of the arguments has to be provided. At least one of the arguments has to be provided.
""" """
def __init__(self, mean=None, covariance=None, domain=None): def __init__(self, mean=None, inverse_covariance=None, domain=None):
if mean is not None and not isinstance(mean, (Field, MultiField)): if mean is not None and not isinstance(mean, (Field, MultiField)):
raise TypeError raise TypeError
if covariance is not None and not isinstance(covariance, if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
LinearOperator):
raise TypeError raise TypeError
self._domain = None self._domain = None
if mean is not None: if mean is not None:
self._checkEquivalence(mean.domain) self._checkEquivalence(mean.domain)
if covariance is not None: if inverse_covariance is not None:
self._checkEquivalence(covariance.domain) self._checkEquivalence(inverse_covariance.domain)
if domain is not None: if domain is not None:
self._checkEquivalence(domain) self._checkEquivalence(domain)
if self._domain is None: if self._domain is None:
raise ValueError("no domain given") raise ValueError("no domain given")
self._mean = mean self._mean = mean
if covariance is None: if inverse_covariance is None:
self._op = SquaredNormOperator(self._domain).scale(0.5) self._op = SquaredNormOperator(self._domain).scale(0.5)
else: else:
self._op = QuadraticFormOperator(covariance.inverse) self._op = QuadraticFormOperator(inverse_covariance)
self._icov = None if covariance is None else covariance.inverse self._icov = None if inverse_covariance is None else inverse_covariance
def _checkEquivalence(self, newdom): def _checkEquivalence(self, newdom):
newdom = makeDomain(newdom) newdom = makeDomain(newdom)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment