Commit 6251760d authored by Philipp Arras's avatar Philipp Arras

Simplify KL

parent 049e7fa9
Pipeline #75762 passed with stages
in 23 minutes and 29 seconds
......@@ -19,6 +19,7 @@ import numpy as np
from .. import random
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
......@@ -115,7 +116,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None,
nanisinf=False, _ham4eval=None):
nanisinf=False):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
......@@ -124,8 +125,6 @@ class MetricGaussianKL(Energy):
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._constants = tuple(constants)
self._point_estimates = tuple(point_estimates)
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool):
raise TypeError
......@@ -134,15 +133,11 @@ class MetricGaussianKL(Energy):
'Point estimates for whole domain. Use EnergyAdapter instead.')
self._hamiltonian = hamiltonian
self._ham4eval = _ham4eval
if self._ham4eval is None:
if len(constants) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, self._ham4eval = hamiltonian.simplify_for_constant_input(cstpos)
else:
self._ham4eval = hamiltonian
if len(constants) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
self._n_samples = int(n_samples)
if comm is not None:
......@@ -160,14 +155,13 @@ class MetricGaussianKL(Energy):
self._n_eff_samples *= 2
if _local_samples is None:
sample_hamiltonian = hamiltonian
if len(point_estimates) > 0:
dom = {kk: vv for kk, vv in mean.domain.items()
if kk in point_estimates}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, sample_hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
met = sample_hamiltonian(Linearization.make_var(mean, True)).metric
_, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
met = hamiltonian(Linearization.make_var(mean, True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_local_samples = []
......@@ -183,27 +177,26 @@ class MetricGaussianKL(Energy):
self._lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
tmp = self._ham4eval(self._lin+s)
tmp = self._hamiltonian(self._lin+s)
tv = tmp.val.val
tg = tmp.gradient
if self._mirror_samples:
tmp = self._ham4eval(self._lin-s)
tmp = self._hamiltonian(self._lin-s)
tv = tv + tmp.val.val
tg = tg + tmp.gradient
v.append(tv)
g.append(tg)
self._val = self._sumup(v)[()]/self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans:
if self._mitigate_nans and np.isnan(self._val):
self._val = np.inf
self._grad = self._sumup(g)/self._n_eff_samples
self._metric = None
def at(self, position):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples, nanisinf=self._mitigate_nans,
_ham4eval=self._ham4eval)
position, self._hamiltonian, self._n_samples,
mirror_samples=self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples, nanisinf=self._mitigate_nans)
@property
def value(self):
......@@ -217,9 +210,9 @@ class MetricGaussianKL(Energy):
lin = self._lin.with_want_metric()
res = []
for s in self._local_samples:
tmp = self._ham4eval(lin+s).metric(x)
tmp = self._hamiltonian(lin+s).metric(x)
if self._mirror_samples:
tmp = tmp + self._ham4eval(lin-s).metric(x)
tmp = tmp + self._hamiltonian(lin-s).metric(x)
res.append(tmp)
return self._sumup(res)/self._n_eff_samples
......@@ -268,6 +261,10 @@ class MetricGaussianKL(Energy):
def _metric_sample(self, from_inverse=False):
if from_inverse:
raise NotImplementedError()
s = ('This draws from the Hamiltonian used for evaluation and does '
' not take point_estimates into accout. Make sure that this '
'is your intended use.')
logger.warning(s)
lin = self._lin.with_want_metric()
samp = []
sseq = random.spawn_sseq(self._n_samples)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment