Commit d569a79e authored by Philipp Frank's avatar Philipp Frank
Browse files

adjust variances functionality

parent 81663bbf
...@@ -78,6 +78,7 @@ from .library.los_response import LOSResponse ...@@ -78,6 +78,7 @@ from .library.los_response import LOSResponse
from .library.wiener_filter_curvature import WienerFilterCurvature from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.correlated_fields import CorrelatedField, MfCorrelatedField from .library.correlated_fields import CorrelatedField, MfCorrelatedField
from .library.adjust_variances import make_adjust_variances
from . import extra from . import extra
......
from ..operators.energy_operators import InverseGammaLikelihood
from ..operators.scaling_operator import ScalingOperator
def make_adjust_variances(a,xi,position,samples=[],scaling=None):
""" Creates a Likelihood for constant likelihood optimizations.
Constructs a Likelihood to solve constant likelihood optimizations of the form
phi = a * xi
under the constraint that phi remains constant.
Parameters
----------
a : Operator
Operator which gives the amplitude when evaluated at a position
xi : Operator
Operator which gives the excitation when evaluated at a position
postion : Field, MultiField
Position of the whole problem
res_samples : Field, MultiField
Residual samples of the whole Problem
scaling : Float
Optional rescaling of the Likelihood
Returns
-------
InverseGammaLikelihood
A Likelihood that can be used for further minimization
"""
d = a * xi
d = (d.conjugate()*d).real
n = len(samples)
if n>0:
d_eval = 0.
for i in range(n):
d_eval = d_eval + d(position+samples[i])
d_eval = d_eval / n
else:
d_eval = d(position)
x = (a.conjugate()*a).real
if scaling is not None:
x = ScalingOperator(scaling,x.target)(x)
return InverseGammaLikelihood(x,d_eval)
\ No newline at end of file
...@@ -120,7 +120,7 @@ class InverseGammaLikelihood(EnergyOperator): ...@@ -120,7 +120,7 @@ class InverseGammaLikelihood(EnergyOperator):
def apply(self, x): def apply(self, x):
x = self._op(x) x = self._op(x)
res = 0.5*(x.log().sum() + (0.5/x).vdot(self._d)) res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(res) return Field.scalar(res)
if not x.want_metric: if not x.want_metric:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment