Commit 56d62904 authored by Philipp Arras's avatar Philipp Arras
Browse files

Simplify Hamiltonian if possible

parent c1fdf306
Pipeline #75576 failed with stages
in 3 minutes and 58 seconds
......@@ -137,7 +137,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None,
nanisinf=False):
nanisinf=False, _ham4eval=None):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -153,6 +153,15 @@ class MetricGaussianKL(Energy):
raise TypeError
self._hamiltonian = hamiltonian
self._ham4eval = _ham4eval
from ..sugar import makeDomain
if self._ham4eval is None:
if len(constants) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
dom = makeDomain(dom)
self._ham4eval = hamiltonian.simplify_for_constant_input(mean.extract(dom))[1]
else:
self._ham4eval = hamiltonian
self._n_samples = int(n_samples)
if comm is not None:
......@@ -170,7 +179,12 @@ class MetricGaussianKL(Energy):
self._n_eff_samples *= 2
if _local_samples is None:
met = hamiltonian(Linearization.make_partial_var(
sample_hamiltonian = hamiltonian
if len(point_estimates) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() if kk in point_estimates}
dom = makeDomain(dom)
sample_hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract(dom))[1]
met = sample_hamiltonian(Linearization.make_partial_var(
mean, self._point_estimates, True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
......@@ -187,14 +201,14 @@ class MetricGaussianKL(Energy):
self._lin = Linearization.make_partial_var(mean, self._constants)
v, g = None, None
if len(self._local_samples) == 0: # hack if there are too many MPI tasks
tmp = self._hamiltonian(self._lin)
tmp = self._ham4eval(self._lin)
v = 0. * tmp.val.val
g = 0. * tmp.gradient
else:
for s in self._local_samples:
tmp = self._hamiltonian(self._lin+s)
tmp = self._ham4eval(self._lin+s)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(self._lin-s)
tmp = tmp + self._ham4eval(self._lin-s)
if v is None:
v = tmp.val.val_rw()
g = tmp.gradient
......@@ -211,7 +225,7 @@ class MetricGaussianKL(Energy):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples, nanisinf=self._mitigate_nans)
_local_samples=self._local_samples, nanisinf=self._mitigate_nans, _ham4eval=self._ham4eval)
@property
def value(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment