Commit 37047b33 authored by Philipp Arras

Add maker method to MetricGaussianKL

parent fb0a4dc3
Pipeline #76915 passed with stages in 11 minutes and 52 seconds
Changes since NIFTy 6
=====================
*None.*
MetricGaussianKL interface
--------------------------
Users no longer instantiate `MetricGaussianKL` via its constructor. Instead,
`MetricGaussianKL.make()` shall be used.
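For example, a minimal sketch (assuming `mean`, `H` and `N_samples` already hold a position field, a `StandardHamiltonian` and a sample count, as in the demo scripts):

    import nifty7 as ift
    # Old (NIFTy 6): KL = ift.MetricGaussianKL(mean, H, N_samples)
    KL = ift.MetricGaussianKL.make(mean, H, N_samples)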
Changes since NIFTy 5
......
......@@ -131,7 +131,7 @@ def main():
# Draw new samples to approximate the KL five times
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL, convergence = minimizer(KL)
mean = KL.position
......@@ -144,7 +144,7 @@ def main():
name=filename.format("loop_{:02d}".format(i)))
# Draw posterior samples
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
sc = ift.StatCalculator()
for sample in KL.samples:
sc.add(signal(sample + KL.position))
......
......@@ -131,7 +131,7 @@ def main():
for i in range(10):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL, convergence = minimizer(KL)
mean = KL.position
......@@ -157,7 +157,7 @@ def main():
name=filename.format("loop_{:02d}".format(i)))
# Done, draw posterior samples
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
sc = ift.StatCalculator()
scA1 = ift.StatCalculator()
scA2 = ift.StatCalculator()
......
......@@ -34,6 +34,7 @@ from matplotlib.colors import LogNorm
import nifty7 as ift
def main():
dom = ift.UnstructuredDomain(1)
scale = 10
......@@ -90,7 +91,7 @@ def main():
plt.figure(figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL(pos, ham, 40)
mgkl = ift.MetricGaussianKL.make(pos, ham, 40)
plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
......
......@@ -24,7 +24,7 @@ from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeDomain, makeOp
from ..sugar import makeOp
from .energy import Energy
......@@ -42,6 +42,11 @@ class _KLMetric(EndomorphicOperator):
return self._KL._metric_sample(from_inverse)
def _get_lo_hi(comm, n_samples):
ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
return utilities.shareRange(n_samples, ntask, rank)
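# Illustrative note, assuming utilities.shareRange splits n_samples as evenly
# as possible across the communicator (lower ranks absorbing the remainder):
# n_samples=7 on a 3-task comm yields the local index ranges (0, 3), (3, 5)
# and (5, 7); with comm=None a single task owns the full range (0, 7).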
class MetricGaussianKL(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
......@@ -58,58 +63,91 @@ class MetricGaussianKL(Energy):
true probability distribution the standard parametrization is assumed.
The samples of this class can be distributed among MPI tasks.
Parameters
----------
mean : Field
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
mirror_samples : boolean
Whether the negatives of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples doubles.
Mirroring samples stabilizes the KL estimate, as extreme sample variation
is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task.
nanisinf : bool
If true, NaN energies, which can happen due to overflows in the forward
model, are interpreted as inf. Thereby, the code does not crash on these
occasions; rather, the minimizer is told that the position it has tried
is not sensible.
_local_samples : None
Only a parameter for internal use. Typically not to be set by users.
Note
----
The two lists `constants` and `point_estimates` are independent of each
other. It is possible to sample along domains which are kept constant
during minimization, and vice versa.
Notes
-----
Instances of `MetricGaussianKL` should never be created using the
constructor, but rather via the factory function :attr:`make`!
See also
--------
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
"""
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None,
nanisinf=False):
def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
local_samples, nanisinf, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(MetricGaussianKL, self).__init__(mean)
self._hamiltonian = hamiltonian
self._n_samples = int(n_samples)
self._mirror_samples = bool(mirror_samples)
self._comm = comm
self._local_samples = local_samples
self._nanisinf = bool(nanisinf)
lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
tmp = hamiltonian(lin+s)
tv = tmp.val.val
tg = tmp.gradient
if mirror_samples:
tmp = hamiltonian(lin-s)
tv = tv + tmp.val.val
tg = tg + tmp.gradient
v.append(tv)
g.append(tg)
self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
if np.isnan(self._val) and self._nanisinf:
self._val = np.inf
self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
@staticmethod
def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[],
mirror_samples=False, napprox=0, comm=None, nanisinf=False):
"""Return instance of :class:`MetricGaussianKL`.
Parameters
----------
mean : Field
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
mirror_samples : boolean
Whether the negatives of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples doubles.
Mirroring samples stabilizes the KL estimate, as extreme sample variation
is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task.
nanisinf : bool
If true, NaN energies, which can happen due to overflows in the forward
model, are interpreted as inf. Thereby, the code does not crash on these
occasions; rather, the minimizer is told that the position it has tried
is not sensible.
_local_samples : None
Only a parameter for internal use. Typically not to be set by users.
Note
----
The two lists `constants` and `point_estimates` are independent of each
other. It is possible to sample along domains which are kept constant
during minimization, and vice versa.
"""
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
......@@ -117,72 +155,39 @@ class MetricGaussianKL(Energy):
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool):
raise TypeError
if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
raise RuntimeError(
'Point estimates for whole domain. Use EnergyAdapter instead.')
n_samples = int(n_samples)
mirror_samples = bool(mirror_samples)
self._hamiltonian = hamiltonian
if len(constants) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
self._n_samples = int(n_samples)
self._comm = comm
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
self._mirror_samples = bool(mirror_samples)
self._n_eff_samples = self._n_samples
if self._mirror_samples:
self._n_eff_samples *= 2
if _local_samples is None:
if len(point_estimates) > 0:
dom = {kk: vv for kk, vv in mean.domain.items()
if kk in point_estimates}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
met = hamiltonian(Linearization.make_var(mean, True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_local_samples = []
sseq = random.spawn_sseq(self._n_samples)
for i in range(self._lo, self._hi):
with random.Context(sseq[i]):
_local_samples.append(met.draw_sample(from_inverse=True))
_local_samples = tuple(_local_samples)
if isinstance(mean, MultiField):
cstpos = mean.extract_by_keys(point_estimates)
_, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
else:
if len(_local_samples) != self._hi-self._lo:
raise ValueError("# of samples mismatch")
self._local_samples = _local_samples
self._lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
tmp = self._hamiltonian(self._lin+s)
tv = tmp.val.val
tg = tmp.gradient
if self._mirror_samples:
tmp = self._hamiltonian(self._lin-s)
tv = tv + tmp.val.val
tg = tg + tmp.gradient
v.append(tv)
g.append(tg)
self._val = utilities.allreduce_sum(v, self._comm)[()]/self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans:
self._val = np.inf
self._grad = utilities.allreduce_sum(g, self._comm)/self._n_eff_samples
ham_sampling = hamiltonian
met = ham_sampling(Linearization.make_var(mean, True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
local_samples = []
sseq = random.spawn_sseq(n_samples)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
local_samples.append(met.draw_sample(from_inverse=True))
local_samples = tuple(local_samples)
if isinstance(mean, MultiField):
_, hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(constants))
return MetricGaussianKL(
mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
nanisinf, _callingfrommake=True)
def at(self, position):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples,
mirror_samples=self._mirror_samples, comm=self._comm,
_local_samples=self._local_samples, nanisinf=self._mitigate_nans)
position, self._hamiltonian, self._n_samples, self._mirror_samples,
self._comm, self._local_samples, self._nanisinf, True)
@property
def value(self):
......@@ -193,14 +198,20 @@ class MetricGaussianKL(Energy):
return self._grad
def apply_metric(self, x):
lin = self._lin.with_want_metric()
lin = Linearization.make_var(self.position, want_metric=True)
res = []
for s in self._local_samples:
tmp = self._hamiltonian(lin+s).metric(x)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(lin-s).metric(x)
res.append(tmp)
return utilities.allreduce_sum(res, self._comm)/self._n_eff_samples
return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples
@property
def n_eff_samples(self):
if self._mirror_samples:
return 2*self._n_samples
return self._n_samples
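# For example, n_samples=5 yields n_eff_samples == 10 if mirror_samples is
# set and 5 otherwise.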
@property
def metric(self):
......@@ -216,9 +227,10 @@ class MetricGaussianKL(Energy):
yield -s
else:
rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
lo, _ = _get_lo_hi(self._comm, self._n_samples)
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
data = self._local_samples[i-self._lo] if rank == itask else None
data = self._local_samples[i-lo] if rank == itask else None
s = self._comm.bcast(data, root=itask)
yield s
if self._mirror_samples:
......@@ -231,7 +243,7 @@ class MetricGaussianKL(Energy):
' not take point_estimates into account. Make sure that this '
'is your intended use.')
logger.warning(s)
lin = self._lin.with_want_metric()
lin = Linearization.make_var(self.position, True)
samp = []
sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._local_samples):
......@@ -240,4 +252,4 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
samp.append(tmp)
return utilities.allreduce_sum(samp, self._comm)/self._n_eff_samples
return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples
......@@ -248,6 +248,10 @@ class MultiField(Operator):
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
def extract_by_keys(self, keys):
dom = MultiDomain.make({kk: vv for kk, vv in self.domain.items() if kk in keys})
return self.extract(dom)
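# Usage sketch (hypothetical names): if `m` is a MultiField with keys
# {'xi', 'tau'}, then m.extract_by_keys(['xi']) returns a MultiField holding
# only the 'xi' component; keys absent from m.domain are silently ignored by
# the comprehension above.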
def extract_part(self, subset):
if subset is self._domain:
return self
......
......@@ -275,22 +275,25 @@ class Operator(metaclass=NiftyMeta):
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
if c_inp is None:
return None, self
dom = c_inp.domain
if isinstance(dom, MultiDomain) and len(dom) == 0:
return None, self
# Convention: If c_inp is MultiField, it needs to be defined on a
# subdomain of self._domain
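# Illustrative example (hypothetical operator): for an operator with
# MultiDomain {'a': ..., 'b': ...} and a c_inp defined over {'a': ...} only,
# the simplified operator returned below treats the 'a' slot as fixed at
# c_inp['a'], so that a subsequent minimization effectively varies only 'b'.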
if isinstance(self.domain, MultiDomain):
assert isinstance(c_inp.domain, MultiDomain)
assert isinstance(dom, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
raise ValueError
if c_inp.domain is self.domain:
if dom is self.domain:
if isinstance(self, EnergyOperator):
op = ConstantEnergyOperator(self.domain, self(c_inp))
else:
op = ConstantOperator(self.domain, self(c_inp))
op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
if not isinstance(c_inp.domain, MultiDomain):
if not isinstance(dom, MultiDomain):
raise RuntimeError
return self._simplify_for_constant_input_nontrivial(c_inp)
......
......@@ -520,7 +520,7 @@ def calculate_position(operator, output):
minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
for ii in range(3):
logger.info(f'Start iteration {ii+1}/3')
kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
kl = MetricGaussianKL.make(pos, H, 3, mirror_samples=True)
kl, _ = minimizer(kl)
pos = kl.position
return pos
......@@ -52,9 +52,9 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
'hamiltonian': h}
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
ift.MetricGaussianKL(**args)
ift.MetricGaussianKL.make(**args)
return
kl = ift.MetricGaussianKL(**args)
kl = ift.MetricGaussianKL.make(**args)
assert_(len(ic.history) > 0)
assert_(len(ic.history) == len(ic.history.time_stamps))
assert_(len(ic.history) == len(ic.history.energy_values))
......@@ -64,13 +64,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
assert_(len(ic.history) == len(ic.history.energy_values))
locsamp = kl._local_samples
klpure = ift.MetricGaussianKL(mean0,
h,
nsamps,
mirror_samples=mirror_samples,
constants=constants,
point_estimates=point_estimates,
_local_samples=locsamp)
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None, locsamp, False, True)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......
......@@ -60,19 +60,27 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
'hamiltonian': h}
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
ift.MetricGaussianKL(**args, comm=comm)
ift.MetricGaussianKL.make(**args, comm=comm)
return
if mode == 0:
kl0 = ift.MetricGaussianKL(**args, comm=comm)
kl0 = ift.MetricGaussianKL.make(**args, comm=comm)
locsamp = kl0._local_samples
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
elif mode == 1:
kl0 = ift.MetricGaussianKL(**args)
kl0 = ift.MetricGaussianKL.make(**args)
samples = kl0._local_samples
ii = len(samples)//2
slc = slice(None, ii) if rank == 0 else slice(ii, None)
locsamp = samples[slc]
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......
......@@ -16,11 +16,15 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import pytest
from numpy.testing import assert_equal
import nifty7 as ift
from .common import setup_function, teardown_function
pmp = pytest.mark.parametrize
def test_get_signal_variance():
space = ift.RGSpace(3)
......@@ -45,15 +49,13 @@ def test_exec_time():
lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op1
ic = ift.GradientNormController(iteration_limit=2)
ham = ift.StandardHamiltonian(lh, ic_samp=ic)
kl = ift.MetricGaussianKL(ift.full(ham.domain, 0.), ham, 1)
kl = ift.MetricGaussianKL.make(ift.full(ham.domain, 0.), ham, 1)
ops = [op, op1, lh, ham, kl]
for oo in ops:
for wm in [True, False]:
ift.exec_time(oo, wm)
import pytest
pmp = pytest.mark.parametrize
@pmp('mf', [False, True])
@pmp('cplx', [False, True])
def test_calc_pos(mf, cplx):
......