Commit c1a9b35b authored by Theo Steininger's avatar Theo Steininger

First draft of log-normal energy and curvature.

parent 74312e64
from critical_filter import *
from log_normal_wiener_filter import *
from wiener_filter import *
# -*- coding: utf-8 -*-
from log_normal_wiener_filter_curvature import LogNormalWienerFilterCurvature
from log_normal_wiener_filter_energy import LogNormalWienerFilterEnergy
from nifty.operators import EndomorphicOperator,\
InvertibleOperatorMixin
from nifty.basic_arithmetics import exp
class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
"""The curvature of the LogNormalWienerFilterEnergy.
This operator implements the second derivative of the
LogNormalWienerFilterEnergy used in some minimization algorithms or for
error estimates of the posterior maps. It is the inverse of the propagator
operator.
Parameters
----------
R: LinearOperator,
The response operator of the Wiener filter measurement.
N : EndomorphicOperator
The noise covariance.
S: DiagonalOperator,
The prior signal covariance
"""
def __init__(self, R, N, S, inverter=None, preconditioner=None, **kwargs):
self.R = R
self.N = N
self.S = S
if preconditioner is None:
preconditioner = self.S.times
self._domain = self.S.domain
super(WienerFilterCurvature, self).__init__(
inverter=inverter,
preconditioner=preconditioner,
**kwargs)
@property
def domain(self):
return self._domain
@property
def self_adjoint(self):
return True
@property
def unitary(self):
return False
# ---Added properties and methods---
def _times(self, x, spaces):
expx = exp(x)
expxx = expx*x
part1 = self.S.inverse_times(x)
part2 = (expx * # is an adjoint necessary here?
self.R.adjoint_times(self.N.inverse_times(self.R(expxx))))
part3 = (expxx * # is an adjoint necessary here?
self.R.adjoint_times(self.N.inverse_times(self.R(expx))))
return part1 + part2 + part3
from nifty.energies.energy import Energy
from nifty.energies.memoization import memo
from nifty.library.wiener_filter import WienerFilterCurvature
from nifty.basic_arithmetics import exp
class LogNormalWienerFilterEnergy(Energy):
"""The Energy for the log-normal Wiener filter.
It covers the case of linear measurement with
Gaussian noise and Gaussain signal prior with known covariance.
Parameters
----------
position: Field,
The current position.
d : Field,
the data.
R : Operator,
The response operator, describtion of the measurement process.
N : EndomorphicOperator,
The noise covariance in data space.
S : EndomorphicOperator,
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S):
super(WienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
self.N = N
self.S = S
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S)
@property
@memo
def value(self):
return (0.5*self.position.vdot(self._Sx) -
self._Rexpxd.vdot(self._NRexpxd))
@property
@memo
def gradient(self):
return self._Sx + self._expx * self.R.adjoint_times(self._NRexpxd)
@property
@memo
def curvature(self):
return WienerFilterCurvature(R=self.R, N=self.N, S=self.S)
@property
@memo
def _expx(self):
return exp(self.position)
@property
@memo
def _Rexpxd(self):
return self.R(self._expx) - self.d
@property
@memo
def _NRexpxd(self):
return self.N.inverse_times(self._Rexpxd)
@property
@memo
def _Sx(self):
return self.S.inverse_times(self.position)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment