Commit 37047b33 authored by Philipp Arras's avatar Philipp Arras

Add maker method to MetricGaussianKL

parent fb0a4dc3
Pipeline #76915 passed with stages in 11 minutes and 52 seconds
 Changes since NIFTy 6
 =====================

-*None.*
+MetricGaussianKL interface
+--------------------------
+Users do not instantiate `MetricGaussianKL` via its constructor anymore.
+Use `MetricGaussianKL.make()` instead.

 Changes since NIFTy 5
......
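For readers migrating existing scripts, a minimal before/after sketch of the change (here `mean`, `H` and `N_samples` stand for an existing latent position, a `StandardHamiltonian` and a sample count, as in the examples below):

    import nifty7 as ift

    # NIFTy 6 style; the constructor now raises NotImplementedError:
    # KL = ift.MetricGaussianKL(mean, H, N_samples)

    # NIFTy 7 style:
    KL = ift.MetricGaussianKL.make(mean, H, N_samples)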
...@@ -131,7 +131,7 @@ def main():
     # Draw new samples to approximate the KL five times
     for i in range(5):
         # Draw new samples and minimize KL
-        KL = ift.MetricGaussianKL(mean, H, N_samples)
+        KL = ift.MetricGaussianKL.make(mean, H, N_samples)
         KL, convergence = minimizer(KL)
         mean = KL.position
...@@ -144,7 +144,7 @@ def main():
                  name=filename.format("loop_{:02d}".format(i)))
     # Draw posterior samples
-    KL = ift.MetricGaussianKL(mean, H, N_samples)
+    KL = ift.MetricGaussianKL.make(mean, H, N_samples)
     sc = ift.StatCalculator()
     for sample in KL.samples:
         sc.add(signal(sample + KL.position))
......
...@@ -131,7 +131,7 @@ def main():
     for i in range(10):
         # Draw new samples and minimize KL
-        KL = ift.MetricGaussianKL(mean, H, N_samples)
+        KL = ift.MetricGaussianKL.make(mean, H, N_samples)
         KL, convergence = minimizer(KL)
         mean = KL.position
...@@ -157,7 +157,7 @@ def main():
                  name=filename.format("loop_{:02d}".format(i)))
     # Done, draw posterior samples
-    KL = ift.MetricGaussianKL(mean, H, N_samples)
+    KL = ift.MetricGaussianKL.make(mean, H, N_samples)
     sc = ift.StatCalculator()
     scA1 = ift.StatCalculator()
     scA2 = ift.StatCalculator()
......
...@@ -34,6 +34,7 @@ from matplotlib.colors import LogNorm
 import nifty7 as ift

+
 def main():
     dom = ift.UnstructuredDomain(1)
     scale = 10
...@@ -90,7 +91,7 @@ def main():
     plt.figure(figsize=[12, 8])
     for ii in range(15):
         if ii % 3 == 0:
-            mgkl = ift.MetricGaussianKL(pos, ham, 40)
+            mgkl = ift.MetricGaussianKL.make(pos, ham, 40)
         plt.cla()
         plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
......
...@@ -24,7 +24,7 @@ from ..multi_field import MultiField
 from ..operators.endomorphic_operator import EndomorphicOperator
 from ..operators.energy_operators import StandardHamiltonian
 from ..probing import approximation2endo
-from ..sugar import makeDomain, makeOp
+from ..sugar import makeOp
 from .energy import Energy
...@@ -42,6 +42,11 @@ class _KLMetric(EndomorphicOperator):
         return self._KL._metric_sample(from_inverse)

+
+def _get_lo_hi(comm, n_samples):
+    ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
+    return utilities.shareRange(n_samples, ntask, rank)
+

 class MetricGaussianKL(Energy):
     """Provides the sampled Kullback-Leibler divergence between a distribution
     and a Metric Gaussian.
...@@ -58,58 +63,91 @@ class MetricGaussianKL(Energy):
     true probability distribution the standard parametrization is assumed.
     The samples of this class can be distributed among MPI tasks.
-    Parameters
-    ----------
-    mean : Field
-        Mean of the Gaussian probability distribution.
-    hamiltonian : StandardHamiltonian
-        Hamiltonian of the approximated probability distribution.
-    n_samples : integer
-        Number of samples used to stochastically estimate the KL.
-    constants : list
-        List of parameter keys that are kept constant during optimization.
-        Default is no constants.
-    point_estimates : list
-        List of parameter keys for which no samples are drawn, but that are
-        (possibly) optimized for, corresponding to point estimates of these.
-        Default is to draw samples for the complete domain.
-    mirror_samples : boolean
-        Whether the negatives of the drawn samples are also used, as they are
-        equally legitimate samples. If true, the number of used samples
-        doubles. Mirroring samples stabilizes the KL estimate as extreme
-        sample variation is counterbalanced. Default is False.
-    napprox : int
-        Number of samples for computing the preconditioner for sampling. No
-        preconditioning is done by default.
-    comm : MPI communicator or None
-        If not None, samples will be distributed as evenly as possible across
-        this communicator. If `mirror_samples` is set, then a sample and its
-        mirror image will always reside on the same task.
-    nanisinf : bool
-        If true, nan energies, which can happen due to overflows in the
-        forward model, are interpreted as inf. Thereby, the code does not
-        crash on these occasions but rather the minimizer is told that the
-        position it has tried is not sensible.
-    _local_samples : None
-        Only a parameter for internal use. Typically not to be set by users.
-
-    Note
-    ----
-    The two lists `constants` and `point_estimates` are independent of each
-    other. It is possible to sample along domains which are kept constant
-    during minimization and vice versa.
+    Notes
+    -----
+    `MetricGaussianKL` objects should never be created using the constructor,
+    but rather via the factory function :attr:`make`!

     See also
     --------
     `Metric Gaussian Variational Inference`, Jakob Knollmüller,
     Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
     """
-    def __init__(self, mean, hamiltonian, n_samples, constants=[],
-                 point_estimates=[], mirror_samples=False,
-                 napprox=0, comm=None, _local_samples=None,
-                 nanisinf=False):
+    def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
+                 local_samples, nanisinf, _callingfrommake=False):
+        if not _callingfrommake:
+            raise NotImplementedError
         super(MetricGaussianKL, self).__init__(mean)
+        self._hamiltonian = hamiltonian
+        self._n_samples = int(n_samples)
+        self._mirror_samples = bool(mirror_samples)
+        self._comm = comm
+        self._local_samples = local_samples
+        self._nanisinf = bool(nanisinf)
+
+        lin = Linearization.make_var(mean)
+        v, g = [], []
+        for s in self._local_samples:
+            tmp = hamiltonian(lin+s)
+            tv = tmp.val.val
+            tg = tmp.gradient
+            if mirror_samples:
+                tmp = hamiltonian(lin-s)
+                tv = tv + tmp.val.val
+                tg = tg + tmp.gradient
+            v.append(tv)
+            g.append(tg)
+        self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
+        if np.isnan(self._val) and self._nanisinf:
+            self._val = np.inf
+        self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
+    @staticmethod
+    def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[],
+             mirror_samples=False, napprox=0, comm=None, nanisinf=False):
+        """Return instance of :class:`MetricGaussianKL`.
+
+        Parameters
+        ----------
+        mean : Field
+            Mean of the Gaussian probability distribution.
+        hamiltonian : StandardHamiltonian
+            Hamiltonian of the approximated probability distribution.
+        n_samples : integer
+            Number of samples used to stochastically estimate the KL.
+        constants : list
+            List of parameter keys that are kept constant during optimization.
+            Default is no constants.
+        point_estimates : list
+            List of parameter keys for which no samples are drawn, but that
+            are (possibly) optimized for, corresponding to point estimates of
+            these. Default is to draw samples for the complete domain.
+        mirror_samples : boolean
+            Whether the negatives of the drawn samples are also used, as they
+            are equally legitimate samples. If true, the number of used
+            samples doubles. Mirroring samples stabilizes the KL estimate as
+            extreme sample variation is counterbalanced. Default is False.
+        napprox : int
+            Number of samples for computing the preconditioner for sampling.
+            No preconditioning is done by default.
+        comm : MPI communicator or None
+            If not None, samples will be distributed as evenly as possible
+            across this communicator. If `mirror_samples` is set, then a
+            sample and its mirror image will always reside on the same task.
+        nanisinf : bool
+            If true, nan energies, which can happen due to overflows in the
+            forward model, are interpreted as inf. Thereby, the code does not
+            crash on these occasions but rather the minimizer is told that the
+            position it has tried is not sensible.
+
+        Note
+        ----
+        The two lists `constants` and `point_estimates` are independent of
+        each other. It is possible to sample along domains which are kept
+        constant during minimization and vice versa.
+        """
         if not isinstance(hamiltonian, StandardHamiltonian):
             raise TypeError
...@@ -117,72 +155,39 @@ class MetricGaussianKL(Energy):
             raise ValueError
         if not isinstance(n_samples, int):
             raise TypeError
-        self._mitigate_nans = nanisinf
         if not isinstance(mirror_samples, bool):
             raise TypeError
         if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
             raise RuntimeError(
                 'Point estimates for whole domain. Use EnergyAdapter instead.')
-        self._hamiltonian = hamiltonian
-        if len(constants) > 0:
-            dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
-            dom = makeDomain(dom)
-            cstpos = mean.extract(dom)
-            _, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
-        self._n_samples = int(n_samples)
-        self._comm = comm
-        ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
-        self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
-        self._mirror_samples = bool(mirror_samples)
-        self._n_eff_samples = self._n_samples
-        if self._mirror_samples:
-            self._n_eff_samples *= 2
-        if _local_samples is None:
-            if len(point_estimates) > 0:
-                dom = {kk: vv for kk, vv in mean.domain.items()
-                       if kk in point_estimates}
-                dom = makeDomain(dom)
-                cstpos = mean.extract(dom)
-                _, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
-            met = hamiltonian(Linearization.make_var(mean, True)).metric
-            if napprox >= 1:
-                met._approximation = makeOp(approximation2endo(met, napprox))
-            _local_samples = []
-            sseq = random.spawn_sseq(self._n_samples)
-            for i in range(self._lo, self._hi):
-                with random.Context(sseq[i]):
-                    _local_samples.append(met.draw_sample(from_inverse=True))
-            _local_samples = tuple(_local_samples)
-        else:
-            if len(_local_samples) != self._hi-self._lo:
-                raise ValueError("# of samples mismatch")
-        self._local_samples = _local_samples
-        self._lin = Linearization.make_var(mean)
-        v, g = [], []
-        for s in self._local_samples:
-            tmp = self._hamiltonian(self._lin+s)
-            tv = tmp.val.val
-            tg = tmp.gradient
-            if self._mirror_samples:
-                tmp = self._hamiltonian(self._lin-s)
-                tv = tv + tmp.val.val
-                tg = tg + tmp.gradient
-            v.append(tv)
-            g.append(tg)
-        self._val = utilities.allreduce_sum(v, self._comm)[()]/self._n_eff_samples
-        if np.isnan(self._val) and self._mitigate_nans:
-            self._val = np.inf
-        self._grad = utilities.allreduce_sum(g, self._comm)/self._n_eff_samples
+        n_samples = int(n_samples)
+        mirror_samples = bool(mirror_samples)
+
+        if isinstance(mean, MultiField):
+            cstpos = mean.extract_by_keys(point_estimates)
+            _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
+        else:
+            ham_sampling = hamiltonian
+        met = ham_sampling(Linearization.make_var(mean, True)).metric
+        if napprox >= 1:
+            met._approximation = makeOp(approximation2endo(met, napprox))
+        local_samples = []
+        sseq = random.spawn_sseq(n_samples)
+        for i in range(*_get_lo_hi(comm, n_samples)):
+            with random.Context(sseq[i]):
+                local_samples.append(met.draw_sample(from_inverse=True))
+        local_samples = tuple(local_samples)
+
+        if isinstance(mean, MultiField):
+            _, hamiltonian = hamiltonian.simplify_for_constant_input(
+                mean.extract_by_keys(constants))
+        return MetricGaussianKL(
+            mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
+            nanisinf, _callingfrommake=True)
     def at(self, position):
         return MetricGaussianKL(
-            position, self._hamiltonian, self._n_samples,
-            mirror_samples=self._mirror_samples, comm=self._comm,
-            _local_samples=self._local_samples, nanisinf=self._mitigate_nans)
+            position, self._hamiltonian, self._n_samples, self._mirror_samples,
+            self._comm, self._local_samples, self._nanisinf, True)
     @property
     def value(self):
...@@ -193,14 +198,20 @@ class MetricGaussianKL(Energy):
         return self._grad
     def apply_metric(self, x):
-        lin = self._lin.with_want_metric()
+        lin = Linearization.make_var(self.position, want_metric=True)
         res = []
         for s in self._local_samples:
             tmp = self._hamiltonian(lin+s).metric(x)
             if self._mirror_samples:
                 tmp = tmp + self._hamiltonian(lin-s).metric(x)
             res.append(tmp)
-        return utilities.allreduce_sum(res, self._comm)/self._n_eff_samples
+        return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples
+
+    @property
+    def n_eff_samples(self):
+        if self._mirror_samples:
+            return 2*self._n_samples
+        return self._n_samples
     @property
     def metric(self):
...@@ -216,9 +227,10 @@ class MetricGaussianKL(Energy):
                 yield -s
         else:
             rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
+            lo, _ = _get_lo_hi(self._comm, self._n_samples)
             for itask, (l, h) in enumerate(rank_lo_hi):
                 for i in range(l, h):
-                    data = self._local_samples[i-self._lo] if rank == itask else None
+                    data = self._local_samples[i-lo] if rank == itask else None
                     s = self._comm.bcast(data, root=itask)
                     yield s
                     if self._mirror_samples:
...@@ -231,7 +243,7 @@ class MetricGaussianKL(Energy):
                  ' not take point_estimates into account. Make sure that this '
                  'is your intended use.')
             logger.warning(s)
-        lin = self._lin.with_want_metric()
+        lin = Linearization.make_var(self.position, True)
         samp = []
         sseq = random.spawn_sseq(self._n_samples)
         for i, v in enumerate(self._local_samples):
...@@ -240,4 +252,4 @@ class MetricGaussianKL(Energy):
             if self._mirror_samples:
                 tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
             samp.append(tmp)
-        return utilities.allreduce_sum(samp, self._comm)/self._n_eff_samples
+        return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples
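To summarize the options that the new factory accepts, here is a hedged usage sketch: it assumes a `StandardHamiltonian` `ham` defined on a MultiDomain with keys 'signal' and 'noise' and a starting position `mean`; the key names and option values are illustrative, not defaults of this commit. Samples are split across MPI tasks via `_get_lo_hi(comm, n_samples)`, and the energy is the average of `hamiltonian(mean+s)` over all (possibly mirrored) samples `s`.

    kl = ift.MetricGaussianKL.make(
        mean, ham, n_samples=10,
        constants=['noise'],        # 'noise' is held fixed during minimization
        point_estimates=['noise'],  # no samples are drawn along 'noise'
        mirror_samples=True,        # for every sample s, -s is used as well
        comm=None,                  # pass an MPI communicator to distribute samples
        nanisinf=True)              # report nan energies as inf to the minimizer
    kl.n_eff_samples                # 20: mirroring doubles the effective count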
...@@ -248,6 +248,10 @@ class MultiField(Operator):
         return MultiField(subset,
                           tuple(self[key] for key in subset.keys()))

+    def extract_by_keys(self, keys):
+        dom = MultiDomain.make({kk: vv for kk, vv in self.domain.items() if kk in keys})
+        return self.extract(dom)
+
     def extract_part(self, subset):
         if subset is self._domain:
             return self
......
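A minimal sketch of what the new helper does, assuming a MultiField `mf` with (hypothetical) keys 'a' and 'b':

    sub = mf.extract_by_keys(['a'])
    # sub is a MultiField over the subdomain {'a'}; this replaces building the
    # MultiDomain by hand and calling mf.extract(dom), as the old __init__ did.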
...@@ -275,22 +275,25 @@ class Operator(metaclass=NiftyMeta):
         from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
         if c_inp is None:
             return None, self
+        dom = c_inp.domain
+        if isinstance(dom, MultiDomain) and len(dom) == 0:
+            return None, self
         # Convention: If c_inp is a MultiField, it needs to be defined on a
         # subdomain of self._domain
         if isinstance(self.domain, MultiDomain):
-            assert isinstance(c_inp.domain, MultiDomain)
+            assert isinstance(dom, MultiDomain)
             if set(c_inp.keys()) > set(self.domain.keys()):
                 raise ValueError
-        if c_inp.domain is self.domain:
+        if dom is self.domain:
             if isinstance(self, EnergyOperator):
                 op = ConstantEnergyOperator(self.domain, self(c_inp))
             else:
                 op = ConstantOperator(self.domain, self(c_inp))
             return op(c_inp), op
-        if not isinstance(c_inp.domain, MultiDomain):
+        if not isinstance(dom, MultiDomain):
             raise RuntimeError
         return self._simplify_for_constant_input_nontrivial(c_inp)
......
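A hedged sketch of the convention stated in the comment above, assuming an operator `op` on a MultiDomain with keys 'a' and 'b' and a MultiField `c` holding only the 'b' component (names hypothetical):

    out, op_simplified = op.simplify_for_constant_input(c)
    # op_simplified treats 'b' as fixed at c; with the new early return, an
    # empty MultiField input (len(dom) == 0) yields (None, op) unchanged.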
...@@ -520,7 +520,7 @@ def calculate_position(operator, output):
     minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
     for ii in range(3):
         logger.info(f'Start iteration {ii+1}/3')
-        kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
+        kl = MetricGaussianKL.make(pos, H, 3, mirror_samples=True)
         kl, _ = minimizer(kl)
         pos = kl.position
     return pos
...@@ -52,9 +52,9 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
             'hamiltonian': h}
     if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
         with assert_raises(RuntimeError):
-            ift.MetricGaussianKL(**args)
+            ift.MetricGaussianKL.make(**args)
         return
-    kl = ift.MetricGaussianKL(**args)
+    kl = ift.MetricGaussianKL.make(**args)
     assert_(len(ic.history) > 0)
     assert_(len(ic.history) == len(ic.history.time_stamps))
     assert_(len(ic.history) == len(ic.history.energy_values))
...@@ -64,13 +64,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
     assert_(len(ic.history) == len(ic.history.energy_values))
     locsamp = kl._local_samples
-    klpure = ift.MetricGaussianKL(mean0,
-                                  h,
-                                  nsamps,
-                                  mirror_samples=mirror_samples,
-                                  constants=constants,
-                                  point_estimates=point_estimates,
-                                  _local_samples=locsamp)
+    if isinstance(mean0, ift.MultiField):
+        _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+    else:
+        tmph = h
+    klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None, locsamp, False, True)

     # Test number of samples
     expected_nsamps = 2*nsamps if mirror_samples else nsamps
......
...@@ -60,19 +60,27 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
             'hamiltonian': h}
     if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
         with assert_raises(RuntimeError):
-            ift.MetricGaussianKL(**args, comm=comm)
+            ift.MetricGaussianKL.make(**args, comm=comm)
         return
     if mode == 0:
-        kl0 = ift.MetricGaussianKL(**args, comm=comm)
+        kl0 = ift.MetricGaussianKL.make(**args, comm=comm)
         locsamp = kl0._local_samples
-        kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
+        if isinstance(mean0, ift.MultiField):
+            _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+        else:
+            tmph = h
+        kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
     elif mode == 1:
-        kl0 = ift.MetricGaussianKL(**args)
+        kl0 = ift.MetricGaussianKL.make(**args)
         samples = kl0._local_samples