Commit 62e6b774 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add do_adjust_variances

parent 398005ee
......@@ -78,7 +78,7 @@ from .library.los_response import LOSResponse
from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.correlated_fields import CorrelatedField, MfCorrelatedField
from .library.adjust_variances import make_adjust_variances
from .library.adjust_variances import make_adjust_variances, do_adjust_variances
from . import extra
......
......@@ -19,11 +19,20 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy_adapter import EnergyAdapter
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..operators.distributors import PowerDistributor
from ..operators.energy_operators import Hamiltonian, InverseGammaLikelihood
from ..operators.scaling_operator import ScalingOperator
from ..operators.simple_linear_operators import FieldAdapter
def make_adjust_variances(a, xi, position, samples=[], scaling=None,
def make_adjust_variances(a,
xi,
position,
samples=[],
scaling=None,
ic_samp=None):
""" Creates a Hamiltonian for constant likelihood optimizations.
......@@ -67,3 +76,36 @@ def make_adjust_variances(a, xi, position, samples=[], scaling=None,
x = ScalingOperator(scaling, x.target)(x)
return Hamiltonian(InverseGammaLikelihood(d_eval)(x), ic_samp=ic_samp)
def do_adjust_variances(position, xi, amplitude_model, minimizer, samples=[]):
h_space = xi.target[0]
pd = PowerDistributor(h_space, amplitude_model.target[0])
a = pd(amplitude_model)
xi = FieldAdapter(MultiDomain.make({"xi": h_space}), "xi")
axi_domain = MultiDomain.union([a.domain, xi.domain])
ham = make_adjust_variances(
a, xi, position.extract(axi_domain), samples=samples)
a_pos = position.extract(a.domain)
# Minimize
# FIXME Try also KL here
e = EnergyAdapter(a_pos, ham, constants=[], want_metric=True)
e, _ = minimizer(e)
# Update position
s_h_old = (a*xi)(position.extract(axi_domain))
position = position.to_dict()
position['xi'] = s_h_old/a(e.position)
position = MultiField.from_dict(position)
position = MultiField.union([position, e.position])
s_h_new = (a*xi)(position.extract(axi_domain))
import numpy as np
# TODO Move this into the tests
np.testing.assert_allclose(s_h_new.to_global_data(),
s_h_old.to_global_data())
return position
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment