Commit 37047b33 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add maker method to MetricGaussianKL

parent fb0a4dc3
Pipeline #76915 passed with stages
in 11 minutes and 52 seconds
Changes since NIFTy 6
=====================
*None.*
MetricGaussianKL interface
--------------------------
Users do not instantiate `MetricGaussianKL` via its constructor anymore. Rather,
`MetricGaussianKL.make()` shall be used.
Changes since NIFTy 5
......
......@@ -131,7 +131,7 @@ def main():
# Draw new samples to approximate the KL five times
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL, convergence = minimizer(KL)
mean = KL.position
......@@ -144,7 +144,7 @@ def main():
name=filename.format("loop_{:02d}".format(i)))
# Draw posterior samples
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
sc = ift.StatCalculator()
for sample in KL.samples:
sc.add(signal(sample + KL.position))
......
......@@ -131,7 +131,7 @@ def main():
for i in range(10):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL, convergence = minimizer(KL)
mean = KL.position
......@@ -157,7 +157,7 @@ def main():
name=filename.format("loop_{:02d}".format(i)))
# Done, draw posterior samples
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
sc = ift.StatCalculator()
scA1 = ift.StatCalculator()
scA2 = ift.StatCalculator()
......
......@@ -34,6 +34,7 @@ from matplotlib.colors import LogNorm
import nifty7 as ift
def main():
dom = ift.UnstructuredDomain(1)
scale = 10
......@@ -90,7 +91,7 @@ def main():
plt.figure(figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL(pos, ham, 40)
mgkl = ift.MetricGaussianKL.make(pos, ham, 40)
plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
......
......@@ -24,7 +24,7 @@ from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeDomain, makeOp
from ..sugar import makeOp
from .energy import Energy
......@@ -42,6 +42,11 @@ class _KLMetric(EndomorphicOperator):
return self._KL._metric_sample(from_inverse)
def _get_lo_hi(comm, n_samples):
    """Return the (lo, hi) index range of the samples owned by this MPI task."""
    n_tasks, this_rank, _ = utilities.get_MPI_params_from_comm(comm)
    return utilities.shareRange(n_samples, n_tasks, this_rank)
class MetricGaussianKL(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
......@@ -58,6 +63,50 @@ class MetricGaussianKL(Energy):
true probability distribution the standard parametrization is assumed.
The samples of this class can be distributed among MPI tasks.
Notes
-----
Instances of this class should never be created using the constructor, but
rather via the factory function :attr:`make`!
See also
--------
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
"""
def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
             local_samples, nanisinf, _callingfrommake=False):
    """Internal constructor -- users must go through :attr:`make`, which
    prepares `local_samples` and the (possibly simplified) Hamiltonian."""
    if not _callingfrommake:
        # Guard against direct instantiation; only make() passes True here.
        raise NotImplementedError
    super(MetricGaussianKL, self).__init__(mean)
    self._hamiltonian = hamiltonian
    self._n_samples = int(n_samples)
    self._mirror_samples = bool(mirror_samples)
    # comm: MPI communicator, or presumably None for serial execution --
    # TODO(review) confirm against utilities.allreduce_sum's contract.
    self._comm = comm
    # Samples held by this MPI task only; global reduction happens below.
    self._local_samples = local_samples
    self._nanisinf = bool(nanisinf)
    lin = Linearization.make_var(mean)
    # Accumulate per-sample energy values (v) and gradients (g).
    v, g = [], []
    for s in self._local_samples:
        tmp = hamiltonian(lin+s)
        tv = tmp.val.val
        tg = tmp.gradient
        if mirror_samples:
            # Also evaluate at the mirrored sample -s; both contributions
            # are summed here and normalized by n_eff_samples below.
            tmp = hamiltonian(lin-s)
            tv = tv + tmp.val.val
            tg = tg + tmp.gradient
        v.append(tv)
        g.append(tg)
    # Reduce over all MPI tasks and average over the effective sample count.
    self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
    if np.isnan(self._val) and self._nanisinf:
        # Optionally map a NaN energy to +inf (nanisinf=True) so that
        # minimizers treat the position as infeasible instead of failing.
        self._val = np.inf
    self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
@staticmethod
def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[],
mirror_samples=False, napprox=0, comm=None, nanisinf=False):
"""Return instance of :class:`MetricGaussianKL`.
Parameters
----------
mean : Field
......@@ -98,91 +147,47 @@ class MetricGaussianKL(Energy):
The two lists `constants` and `point_estimates` are independent from each
other. It is possible to sample along domains which are kept constant
during minimization and vice versa.
See also
--------
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
"""
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _local_samples=None,
nanisinf=False):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not mean.domain:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool):
raise TypeError
if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
raise RuntimeError(
'Point estimates for whole domain. Use EnergyAdapter instead.')
n_samples = int(n_samples)
mirror_samples = bool(mirror_samples)
self._hamiltonian = hamiltonian
if len(constants) > 0:
dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
self._n_samples = int(n_samples)
self._comm = comm
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
self._mirror_samples = bool(mirror_samples)
self._n_eff_samples = self._n_samples
if self._mirror_samples:
self._n_eff_samples *= 2
if _local_samples is None:
if len(point_estimates) > 0:
dom = {kk: vv for kk, vv in mean.domain.items()
if kk in point_estimates}
dom = makeDomain(dom)
cstpos = mean.extract(dom)
_, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
met = hamiltonian(Linearization.make_var(mean, True)).metric
if isinstance(mean, MultiField):
cstpos = mean.extract_by_keys(point_estimates)
_, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
else:
ham_sampling = hamiltonian
met = ham_sampling(Linearization.make_var(mean, True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_local_samples = []
sseq = random.spawn_sseq(self._n_samples)
for i in range(self._lo, self._hi):
local_samples = []
sseq = random.spawn_sseq(n_samples)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
_local_samples.append(met.draw_sample(from_inverse=True))
_local_samples = tuple(_local_samples)
else:
if len(_local_samples) != self._hi-self._lo:
raise ValueError("# of samples mismatch")
self._local_samples = _local_samples
self._lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
tmp = self._hamiltonian(self._lin+s)
tv = tmp.val.val
tg = tmp.gradient
if self._mirror_samples:
tmp = self._hamiltonian(self._lin-s)
tv = tv + tmp.val.val
tg = tg + tmp.gradient
v.append(tv)
g.append(tg)
self._val = utilities.allreduce_sum(v, self._comm)[()]/self._n_eff_samples
if np.isnan(self._val) and self._mitigate_nans:
self._val = np.inf
self._grad = utilities.allreduce_sum(g, self._comm)/self._n_eff_samples
local_samples.append(met.draw_sample(from_inverse=True))
local_samples = tuple(local_samples)
if isinstance(mean, MultiField):
_, hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(constants))
return MetricGaussianKL(
mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
nanisinf, _callingfrommake=True)
def at(self, position):
    """Return this energy re-evaluated at `position`, reusing the samples.

    Reusing the stored local samples keeps the KL estimate consistent
    across minimization steps; `_callingfrommake=True` bypasses the
    constructor guard.
    """
    # Fix: diff residue had left two argument lists in this call (the old
    # keyword form referencing the removed `self._mitigate_nans` plus the
    # new positional form), making the block invalid; keep the new form.
    return MetricGaussianKL(
        position, self._hamiltonian, self._n_samples, self._mirror_samples,
        self._comm, self._local_samples, self._nanisinf, True)
@property
def value(self):
......@@ -193,14 +198,20 @@ class MetricGaussianKL(Energy):
return self._grad
def apply_metric(self, x):
    """Apply the sampled KL metric (averaged Hamiltonian metric) to `x`."""
    # Fix: removed stale diff-residue lines -- an assignment through the
    # no-longer-set `self._lin` and a duplicate return dividing by the
    # removed `self._n_eff_samples` attribute.
    lin = Linearization.make_var(self.position, want_metric=True)
    res = []
    for s in self._local_samples:
        tmp = self._hamiltonian(lin+s).metric(x)
        if self._mirror_samples:
            # Mirrored samples contribute their own metric application.
            tmp = tmp + self._hamiltonian(lin-s).metric(x)
        res.append(tmp)
    # Reduce over MPI tasks and average over the effective sample count.
    return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples
@property
def n_eff_samples(self):
    """Effective number of samples; doubled when samples are mirrored."""
    factor = 2 if self._mirror_samples else 1
    return factor*self._n_samples
@property
def metric(self):
......@@ -216,9 +227,10 @@ class MetricGaussianKL(Energy):
yield -s
else:
rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
lo, _ = _get_lo_hi(self._comm, self._n_samples)
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
data = self._local_samples[i-self._lo] if rank == itask else None
data = self._local_samples[i-lo] if rank == itask else None
s = self._comm.bcast(data, root=itask)
yield s
if self._mirror_samples:
......@@ -231,7 +243,7 @@ class MetricGaussianKL(Energy):
' not take point_estimates into accout. Make sure that this '
'is your intended use.')
logger.warning(s)
lin = self._lin.with_want_metric()
lin = Linearization.make_var(self.position, True)
samp = []
sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._local_samples):
......@@ -240,4 +252,4 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
samp.append(tmp)
return utilities.allreduce_sum(samp, self._comm)/self._n_eff_samples
return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples
......@@ -248,6 +248,10 @@ class MultiField(Operator):
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
def extract_by_keys(self, keys):
    """Return the restriction of this MultiField to the given subset of keys."""
    wanted = {kk: dd for kk, dd in self.domain.items() if kk in keys}
    return self.extract(MultiDomain.make(wanted))
def extract_part(self, subset):
if subset is self._domain:
return self
......
......@@ -275,22 +275,25 @@ class Operator(metaclass=NiftyMeta):
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
if c_inp is None:
return None, self
dom = c_inp.domain
if isinstance(dom, MultiDomain) and len(dom) == 0:
return None, self
# Convention: If c_inp is MultiField, it needs to be defined on a
# subdomain of self._domain
if isinstance(self.domain, MultiDomain):
assert isinstance(c_inp.domain, MultiDomain)
assert isinstance(dom, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
raise ValueError
if c_inp.domain is self.domain:
if dom is self.domain:
if isinstance(self, EnergyOperator):
op = ConstantEnergyOperator(self.domain, self(c_inp))
else:
op = ConstantOperator(self.domain, self(c_inp))
op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
if not isinstance(c_inp.domain, MultiDomain):
if not isinstance(dom, MultiDomain):
raise RuntimeError
return self._simplify_for_constant_input_nontrivial(c_inp)
......
......@@ -520,7 +520,7 @@ def calculate_position(operator, output):
minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
for ii in range(3):
logger.info(f'Start iteration {ii+1}/3')
kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
kl = MetricGaussianKL.make(pos, H, 3, mirror_samples=True)
kl, _ = minimizer(kl)
pos = kl.position
return pos
......@@ -52,9 +52,9 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
'hamiltonian': h}
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
ift.MetricGaussianKL(**args)
ift.MetricGaussianKL.make(**args)
return
kl = ift.MetricGaussianKL(**args)
kl = ift.MetricGaussianKL.make(**args)
assert_(len(ic.history) > 0)
assert_(len(ic.history) == len(ic.history.time_stamps))
assert_(len(ic.history) == len(ic.history.energy_values))
......@@ -64,13 +64,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
assert_(len(ic.history) == len(ic.history.energy_values))
locsamp = kl._local_samples
klpure = ift.MetricGaussianKL(mean0,
h,
nsamps,
mirror_samples=mirror_samples,
constants=constants,
point_estimates=point_estimates,
_local_samples=locsamp)
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None, locsamp, False, True)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......
......@@ -60,19 +60,27 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
'hamiltonian': h}
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
ift.MetricGaussianKL(**args, comm=comm)
ift.MetricGaussianKL.make(**args, comm=comm)
return
if mode == 0:
kl0 = ift.MetricGaussianKL(**args, comm=comm)
kl0 = ift.MetricGaussianKL.make(**args, comm=comm)
locsamp = kl0._local_samples
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
elif mode == 1:
kl0 = ift.MetricGaussianKL(**args)
kl0 = ift.MetricGaussianKL.make(**args)
samples = kl0._local_samples
ii = len(samples)//2
slc = slice(None, ii) if rank == 0 else slice(ii, None)
locsamp = samples[slc]
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......
......@@ -16,11 +16,15 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import pytest
from numpy.testing import assert_equal
import nifty7 as ift
from .common import setup_function, teardown_function
pmp = pytest.mark.parametrize
def test_get_signal_variance():
space = ift.RGSpace(3)
......@@ -45,15 +49,13 @@ def test_exec_time():
lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op1
ic = ift.GradientNormController(iteration_limit=2)
ham = ift.StandardHamiltonian(lh, ic_samp=ic)
kl = ift.MetricGaussianKL(ift.full(ham.domain, 0.), ham, 1)
kl = ift.MetricGaussianKL.make(ift.full(ham.domain, 0.), ham, 1)
ops = [op, op1, lh, ham, kl]
for oo in ops:
for wm in [True, False]:
ift.exec_time(oo, wm)
import pytest
pmp = pytest.mark.parametrize
@pmp('mf', [False, True])
@pmp('cplx', [False, True])
def test_calc_pos(mf, cplx):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment