Skip to content
Snippets Groups Projects
Commit 70a4c327 authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

likelihood: Remove most refs. to Hamiltonian

parent 4a3188ee
No related branches found
No related tags found
1 merge request!904Re likelihood
......@@ -215,7 +215,7 @@ class Likelihood(AbstractModel):
"""
# TODO: track forward model and build lsm, metric, residual only when
# called instead of always partially
self._hamiltonian = energy
self._energy = energy
self._transformation = transformation
self._normalized_residual = normalized_residual
self._left_sqrt_metric = left_sqrt_metric
......@@ -260,7 +260,7 @@ class Likelihood(AbstractModel):
energy : float
Energy at the position `primals`.
"""
return self._hamiltonian(primals, **primals_kw)
return self._energy(primals, **primals_kw)
def normalized_residual(self, primals, **primals_kw):
"""Applies the normalized_residual to `primals`.
......@@ -490,7 +490,7 @@ class Likelihood(AbstractModel):
j_trafo, j_lsm, j_rsm, j_m = None, None, None, None
return self.replace(
jit(self._hamiltonian, **kwargs),
jit(self._energy, **kwargs),
normalized_residual=j_r,
transformation=j_trafo,
left_sqrt_metric=j_lsm,
......@@ -613,7 +613,7 @@ class Likelihood(AbstractModel):
Vector) or isinstance(other._lsm_tan_shp, Vector):
joined_tangents_shape = Vector(joined_tangents_shape)
def joined_hamiltonian(p, **pkw):
def joined_energy(p, **pkw):
return self.energy(p, **pkw) + other.energy(p, **pkw)
def joined_normalized_residual(p, **pkw):
......@@ -668,7 +668,7 @@ class Likelihood(AbstractModel):
raise ValueError(ve)
return Likelihood(
joined_hamiltonian,
joined_energy,
normalized_residual=joined_normalized_residual,
transformation=joined_transformation,
left_sqrt_metric=joined_left_sqrt_metric,
......@@ -778,7 +778,7 @@ class StandardHamiltonian():
"""
self._lh = likelihood
def joined_hamiltonian(primals, **primals_kw):
def joined_energy(primals, **primals_kw):
# Assume the first primals to be the parameters
return self._lh(primals, **
primals_kw) + 0.5 * vdot(primals, primals)
......@@ -788,9 +788,9 @@ class StandardHamiltonian():
if _compile_joined:
from jax import jit
joined_hamiltonian = jit(joined_hamiltonian, **_compile_kwargs)
joined_energy = jit(joined_energy, **_compile_kwargs)
joined_metric = jit(joined_metric, **_compile_kwargs)
self._hamiltonian = joined_hamiltonian
self._energy = joined_energy
self._metric = joined_metric
@doc_from(Likelihood.__call__)
......@@ -799,7 +799,7 @@ class StandardHamiltonian():
@doc_from(Likelihood.energy)
def energy(self, primals, **primals_kw):
return self._hamiltonian(primals, **primals_kw)
return self._energy(primals, **primals_kw)
@doc_from(Likelihood.metric)
def metric(self, primals, tangents, **primals_kw):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment