Commit 6251760d authored by Philipp Arras's avatar Philipp Arras
Browse files

Simplify KL

parent 049e7fa9
Pipeline #75762 passed with stages
in 23 minutes and 29 seconds
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
from .. import random from .. import random
from ..linearization import Linearization from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian from ..operators.energy_operators import StandardHamiltonian
...@@ -115,7 +116,7 @@ class MetricGaussianKL(Energy): ...@@ -115,7 +116,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[], def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False, point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None, napprox=0, comm=None, _local_samples=None,
nanisinf=False, _ham4eval=None): nanisinf=False):
super(MetricGaussianKL, self).__init__(mean) super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian): if not isinstance(hamiltonian, StandardHamiltonian):
...@@ -124,8 +125,6 @@ class MetricGaussianKL(Energy): ...@@ -124,8 +125,6 @@ class MetricGaussianKL(Energy):
raise ValueError raise ValueError
if not isinstance(n_samples, int): if not isinstance(n_samples, int):
raise TypeError raise TypeError
self._constants = tuple(constants)
self._point_estimates = tuple(point_estimates)
self._mitigate_nans = nanisinf self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool): if not isinstance(mirror_samples, bool):
raise TypeError raise TypeError
...@@ -134,15 +133,11 @@ class MetricGaussianKL(Energy): ...@@ -134,15 +133,11 @@ class MetricGaussianKL(Energy):
'Point estimates for whole domain. Use EnergyAdapter instead.') 'Point estimates for whole domain. Use EnergyAdapter instead.')
self._hamiltonian = hamiltonian self._hamiltonian = hamiltonian
self._ham4eval = _ham4eval if len(constants) > 0:
if self._ham4eval is None: dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
if len(constants) > 0: dom = makeDomain(dom)
dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants} cstpos = mean.extract(dom)
dom = makeDomain(dom) _, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
cstpos = mean.extract(dom)
_, self._ham4eval = hamiltonian.simplify_for_constant_input(cstpos)
else:
self._ham4eval = hamiltonian
self._n_samples = int(n_samples) self._n_samples = int(n_samples)
if comm is not None: if comm is not None:
...@@ -160,14 +155,13 @@ class MetricGaussianKL(Energy): ...@@ -160,14 +155,13 @@ class MetricGaussianKL(Energy):
self._n_eff_samples *= 2 self._n_eff_samples *= 2
if _local_samples is None: if _local_samples is None:
sample_hamiltonian = hamiltonian
if len(point_estimates) > 0: if len(point_estimates) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() dom = {kk: vv for kk, vv in mean.domain.items()
if kk in point_estimates} if kk in point_estimates}
dom = makeDomain(dom) dom = makeDomain(dom)
cstpos = mean.extract(dom) cstpos = mean.extract(dom)
_, sample_hamiltonian = hamiltonian.simplify_for_constant_input(cstpos) _, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
met = sample_hamiltonian(Linearization.make_var(mean, True)).metric met = hamiltonian(Linearization.make_var(mean, True)).metric
if napprox >= 1: if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox)) met._approximation = makeOp(approximation2endo(met, napprox))
_local_samples = [] _local_samples = []
...@@ -183,27 +177,26 @@ class MetricGaussianKL(Energy): ...@@ -183,27 +177,26 @@ class MetricGaussianKL(Energy):
self._lin = Linearization.make_var(mean) self._lin = Linearization.make_var(mean)
v, g = [], [] v, g = [], []
for s in self._local_samples: for s in self._local_samples:
tmp = self._ham4eval(self._lin+s) tmp = self._hamiltonian(self._lin+s)
tv = tmp.val.val tv = tmp.val.val
tg = tmp.gradient tg = tmp.gradient
if self._mirror_samples: if self._mirror_samples:
tmp = self._ham4eval(self._lin-s) tmp = self._hamiltonian(self._lin-s)
tv = tv + tmp.val.val tv = tv + tmp.val.val
tg = tg + tmp.gradient tg = tg + tmp.gradient
v.append(tv) v.append(tv)
g.append(tg) g.append(tg)
self._val = self._sumup(v)[()]/self._n_eff_samples self._val = self._sumup(v)[()]/self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans: if self._mitigate_nans and np.isnan(self._val):
self._val = np.inf self._val = np.inf
self._grad = self._sumup(g)/self._n_eff_samples self._grad = self._sumup(g)/self._n_eff_samples
self._metric = None self._metric = None
def at(self, position): def at(self, position):
return MetricGaussianKL( return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants, position, self._hamiltonian, self._n_samples,
self._point_estimates, self._mirror_samples, comm=self._comm, mirror_samples=self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples, nanisinf=self._mitigate_nans, _local_samples=self._local_samples, nanisinf=self._mitigate_nans)
_ham4eval=self._ham4eval)
@property @property
def value(self): def value(self):
...@@ -217,9 +210,9 @@ class MetricGaussianKL(Energy): ...@@ -217,9 +210,9 @@ class MetricGaussianKL(Energy):
lin = self._lin.with_want_metric() lin = self._lin.with_want_metric()
res = [] res = []
for s in self._local_samples: for s in self._local_samples:
tmp = self._ham4eval(lin+s).metric(x) tmp = self._hamiltonian(lin+s).metric(x)
if self._mirror_samples: if self._mirror_samples:
tmp = tmp + self._ham4eval(lin-s).metric(x) tmp = tmp + self._hamiltonian(lin-s).metric(x)
res.append(tmp) res.append(tmp)
return self._sumup(res)/self._n_eff_samples return self._sumup(res)/self._n_eff_samples
...@@ -268,6 +261,10 @@ class MetricGaussianKL(Energy): ...@@ -268,6 +261,10 @@ class MetricGaussianKL(Energy):
def _metric_sample(self, from_inverse=False): def _metric_sample(self, from_inverse=False):
if from_inverse: if from_inverse:
raise NotImplementedError() raise NotImplementedError()
s = ('This draws from the Hamiltonian used for evaluation and does '
' not take point_estimates into accout. Make sure that this '
'is your intended use.')
logger.warning(s)
lin = self._lin.with_want_metric() lin = self._lin.with_want_metric()
samp = [] samp = []
sseq = random.spawn_sseq(self._n_samples) sseq = random.spawn_sseq(self._n_samples)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment