Commit eded7903 authored by Martin Reinecke

try to unify KL classes and make them reproducible independent of the number of MPI tasks

parent f82b084e
Pipeline #71226 passed with stages in 27 minutes and 37 seconds
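The heart of the reproducibility fix, as a minimal standalone sketch (pure numpy, no NIFTy or MPI; `share_range` and `draw_local_samples` are illustrative names, only the splitting logic mirrors the `_shareRange` helper introduced below): each sample is seeded from its global index, so a task draws only its own share, yet the drawn values are identical for any number of tasks.

    import numpy as np

    def share_range(nwork, nshares, myshare):
        # same splitting logic as the _shareRange helper below
        nbase, extra = divmod(nwork, nshares)
        lo = myshare*nbase + min(myshare, extra)
        return lo, lo + nbase + int(myshare < extra)

    def draw_local_samples(n_samples, ntask, rank, root_seed=42):
        # one child SeedSequence per *global* sample index; seeding by the
        # global index makes the draws independent of the task count
        sseqs = np.random.SeedSequence(root_seed).spawn(n_samples)
        lo, hi = share_range(n_samples, ntask, rank)
        return [np.random.default_rng(sseqs[i]).standard_normal(3)
                for i in range(lo, hi)]

    # draw_local_samples(4, 1, 0) equals the concatenation of
    # draw_local_samples(4, 2, 0) and draw_local_samples(4, 2, 1).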
@@ -11,17 +11,64 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeOp
from ..operators.endomorphic_operator import EndomorphicOperator
from .energy import Energy
import numpy as np
from ..probing import approximation2endo
from ..sugar import makeOp, full
from ..field import Field
from ..multi_field import MultiField
from .. import random
def _shareRange(nwork, nshares, myshare):
nbase = nwork//nshares
additional = nwork % nshares
lo = myshare*nbase + min(myshare, additional)
hi = lo + nbase + int(myshare < additional)
return lo, hi
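# Example (hypothetical numbers): distributing 10 samples over 3 tasks yields
# _shareRange(10, 3, 0) == (0, 4), _shareRange(10, 3, 1) == (4, 7) and
# _shareRange(10, 3, 2) == (7, 10), i.e. shares differ in size by at most one.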
def np_allreduce_sum(comm, arr):
if comm is None:
return arr
from mpi4py import MPI
arr = np.array(arr)
res = np.empty_like(arr)
comm.Allreduce(arr, res, MPI.SUM)
return res
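# Note: this is a blocking MPI Allreduce; every task ends up holding the
# elementwise sum of `arr` over all tasks.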
def allreduce_sum_field(comm, fld):
if comm is None:
return fld
if isinstance(fld, Field):
        return Field(fld.domain, np_allreduce_sum(comm, fld.val))
res = tuple(
Field(f.domain, np_allreduce_sum(comm, f.val))
for f in fld.values())
return MultiField(fld.domain, res)
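# For a MultiField each component Field is reduced separately and the result
# is reassembled, so all tasks obtain the full sum over all tasks.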
class KLMetric(EndomorphicOperator):
def __init__(self, KL):
self._KL = KL
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = KL.position.domain
def apply(self, x, mode):
self._check_input(x, mode)
return self._KL.apply_metric(x)
def draw_sample(self, from_inverse=False, dtype=np.float64):
return self._KL.metric_sample(from_inverse, dtype)
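# KLMetric wraps a MetricGaussianKL so that its (possibly MPI-distributed)
# metric looks like an ordinary EndomorphicOperator: applying it and drawing
# samples from it transparently include the reduction over all tasks.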
class MetricGaussianKL(Energy):
@@ -38,6 +85,7 @@ class MetricGaussianKL(Energy):
typically nonlinear structure of the true distribution these samples have
to be updated eventually by instantiating `MetricGaussianKL` again. For the
true probability distribution the standard parametrization is assumed.
The samples of this class can be distributed among MPI tasks.
Parameters
----------
@@ -62,9 +110,16 @@ class MetricGaussianKL(Energy):
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
    use_mpi : bool
        Whether MPI should be used. If MPI is enabled, samples will be
        distributed as evenly as possible across MPI.COMM_WORLD. If
        `mirror_samples` is set, a sample and its mirror image will always
        reside on the same task. Default is False.
_samples : None
        Only a parameter for internal use. Typically not to be set by users.
    lh_sampling_dtype : type
        Dtype used when drawing samples from the likelihood part of the
        Hamiltonian's metric. Default is np.float64.
Note
----
The two lists `constants` and `point_estimates` are independent of each
@@ -79,7 +134,8 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, _samples=None, lh_sampling_dtype=np.float64):
napprox=0, use_mpi=False, _samples=None,
lh_sampling_dtype=np.float64):
super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
@@ -88,47 +144,75 @@
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._constants = list(constants)
self._point_estimates = list(point_estimates)
self._constants = tuple(constants)
self._point_estimates = tuple(point_estimates)
if not isinstance(mirror_samples, bool):
raise TypeError
self._hamiltonian = hamiltonian
self._use_mpi = bool(use_mpi)
if self._use_mpi:
from mpi4py import MPI
self._comm = MPI.COMM_WORLD
self._ntask = self._comm.Get_size()
self._rank = self._comm.Get_rank()
self._master = (self._rank == 0)
else:
self._comm, self._ntask, self._rank, self._master = None, 1, 0, True
self._n_samples = int(n_samples)
self._lo, self._hi = _shareRange(self._n_samples, self._ntask, self._rank)
self._mirror_samples = bool(mirror_samples)
self._n_eff_samples = self._n_samples
if self._mirror_samples:
self._n_eff_samples *= 2
if _samples is None:
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
mean, self._point_estimates, True)).metric
# FIXME: should this be ">=1" instead of ">1"?
if napprox > 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_samples = tuple(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype)
for _ in range(n_samples))
if mirror_samples:
_samples += tuple(-s for s in _samples)
_samples = []
sseq = random.spawn_sseq(self._n_samples)
for i in range(self._lo, self._hi):
random.push_sseq(sseq[i])
_samples.append(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype))
random.pop_sseq()
_samples = tuple(_samples)
else:
            if len(_samples) != self._hi - self._lo:
                raise ValueError("# of samples mismatch")
self._samples = _samples
# FIXME Use simplify for constant input instead
self._lin = Linearization.make_partial_var(mean, constants)
self._lin = Linearization.make_partial_var(mean, self._constants)
v, g = None, None
for s in self._samples:
tmp = self._hamiltonian(self._lin+s)
if v is None:
v = tmp.val.val[()]
g = tmp.gradient
else:
v += tmp.val.val[()]
g = g + tmp.gradient
self._val = v / len(self._samples)
self._grad = g * (1./len(self._samples))
if len(self._samples) == 0: # hack if there are too many MPI tasks
tmp = self._hamiltonian(self._lin)
v = 0. * tmp.val.val
g = 0. * tmp.gradient
else:
for s in self._samples:
tmp = self._hamiltonian(self._lin+s)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(self._lin-s)
if v is None:
v = tmp.val.val_rw()
g = tmp.gradient
else:
v += tmp.val.val
g = g + tmp.gradient
self._val = np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
self._grad = allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._metric = None
self._napprox = napprox
self._sampdt = lh_sampling_dtype
def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._point_estimates,
napprox=self._napprox, _samples=self._samples,
lh_sampling_dtype=self._sampdt)
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, use_mpi=self._use_mpi,
_samples=self._samples, lh_sampling_dtype=self._sampdt)
@property
def value(self):
@@ -139,30 +223,48 @@
return self._grad
def _get_metric(self):
lin = self._lin.with_want_metric()
if self._metric is None:
lin = self._lin.with_want_metric()
mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples)
self._unscaled_metric = utilities.my_sum(mymap)
self._metric = self._unscaled_metric.scale(1./len(self._samples))
def unscaled_metric(self):
self._get_metric()
return self._unscaled_metric, 1/len(self._samples)
            if len(self._samples) == 0:  # hack if there are too many MPI tasks
                self._metric = self._hamiltonian(lin).metric.scale(0.)
            else:
                mymap = map(lambda v: self._hamiltonian(lin+v).metric,
                            self._samples)
                self._unscaled_metric = utilities.my_sum(mymap)
                self._metric = self._unscaled_metric.scale(
                    1./self._n_eff_samples)
def apply_metric(self, x):
self._get_metric()
return self._metric(x)
return allreduce_sum_field(self._comm, self._metric(x))
@property
def metric(self):
self._get_metric()
return self._metric
return KLMetric(self)
@property
def samples(self):
return self._samples
        if self._comm is not None:
            res = self._comm.allgather(self._samples)
            res = tuple(item for sublist in res for item in sublist)
        else:
            res = self._samples
        if self._mirror_samples:
            res = res + tuple(-item for item in res)
        return res
def unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise NotImplementedError()
lin = self._lin.with_want_metric()
samp = full(self._hamiltonian.domain, 0.)
        sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._samples):
random.push_sseq(sseq[self._lo+i])
samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False, dtype=dtype)
if self._mirror_samples:
samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False, dtype=dtype)
random.pop_sseq()
return allreduce_sum_field(self._comm, samp)
def __repr__(self):
return 'KL ({} samples):\n'.format(len(
self._samples)) + utilities.indent(self._hamiltonian.__repr__())
def metric_sample(self, from_inverse=False, dtype=np.float64):
return self.unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples
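For reference, the per-task averaging that `MetricGaussianKL` performs over (possibly mirrored) samples, re-implemented as a self-contained toy in pure numpy; `ham`, the quadratic stand-in, and all numbers are made up for illustration:

    import numpy as np

    def ham(x):                      # stand-in for hamiltonian(lin + s)
        return 0.5*np.sum(x**2), x   # value and gradient

    rng = np.random.default_rng(0)
    mean = rng.standard_normal(4)
    samples = [rng.standard_normal(4) for _ in range(3)]
    mirror = True
    n_eff = 2*len(samples) if mirror else len(samples)

    v, g = 0., np.zeros_like(mean)
    for s in samples:
        val, grad = ham(mean + s)
        if mirror:                   # the mirrored sample contributes too
            val2, grad2 = ham(mean - s)
            val, grad = val + val2, grad + grad2
        v, g = v + val, g + grad
    kl_value, kl_grad = v/n_eff, g/n_eff  # per-task _val/_grad before Allreduce

What follows below is the old `metric_gaussian_kl_mpi.py` (deleted by this commit, together with its import), whose functionality is folded into the unified class above.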
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .. import utilities
from ..linearization import Linearization
from ..operators.energy_operators import StandardHamiltonian
from ..operators.endomorphic_operator import EndomorphicOperator
from .energy import Energy
from mpi4py import MPI
import numpy as np
from .. import random
from ..probing import approximation2endo
from ..sugar import makeOp, full
from ..field import Field
from ..multi_field import MultiField
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
rank = _comm.Get_rank()
master = (rank == 0)
def _shareRange(nwork, nshares, myshare):
nbase = nwork//nshares
additional = nwork % nshares
lo = myshare*nbase + min(myshare, additional)
hi = lo + nbase + int(myshare < additional)
return lo, hi
def np_allreduce_sum(arr):
arr = np.array(arr)
res = np.empty_like(arr)
_comm.Allreduce(arr, res, MPI.SUM)
return res
def allreduce_sum_field(fld):
if isinstance(fld, Field):
return Field(fld.domain, np_allreduce_sum(fld.val))
res = tuple(
Field(f.domain, np_allreduce_sum(f.val))
for f in fld.values())
return MultiField(fld.domain, res)
class KLMetric(EndomorphicOperator):
def __init__(self, KL):
self._KL = KL
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = KL.position.domain
def apply(self, x, mode):
self._check_input(x, mode)
return self._KL.apply_metric(x)
def draw_sample(self, from_inverse=False, dtype=np.float64):
return self._KL.metric_sample(from_inverse, dtype)
class MetricGaussianKL_MPI(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
A Metric Gaussian is used to approximate another probability distribution.
It is a Gaussian distribution that uses the Fisher information metric of
the other distribution at the location of its mean to approximate the
variance. In order to infer the mean, a stochastic estimate of the
Kullback-Leibler divergence is minimized. This estimate is obtained by
sampling the Metric Gaussian at the current mean. During minimization
these samples are kept constant; only the mean is updated. Due to the
typically nonlinear structure of the true distribution these samples have
to be updated eventually by instantiating `MetricGaussianKL` again. For the
true probability distribution the standard parametrization is assumed.
The samples of this class are distributed among MPI tasks.
Parameters
----------
mean : Field
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
mirror_samples : boolean
        Whether the negatives of the drawn samples are also used,
as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
_samples : None
        Only a parameter for internal use. Typically not to be set by users.
Note
----
    The two lists `constants` and `point_estimates` are independent of each
other. It is possible to sample along domains which are kept constant
during minimization and vice versa.
See also
--------
`Metric Gaussian Variational Inference`, Jakob Knollmüller,
Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
"""
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, _samples=None, lh_sampling_dtype=np.float64):
super(MetricGaussianKL_MPI, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not mean.domain:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
self._constants = list(constants)
self._point_estimates = list(point_estimates)
if not isinstance(mirror_samples, bool):
raise TypeError
self._hamiltonian = hamiltonian
if _samples is None:
if mirror_samples:
lo, hi = _shareRange(n_samples*2, ntask, rank)
else:
lo, hi = _shareRange(n_samples, ntask, rank)
met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric
if napprox > 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_samples = []
sseq = random.spawn_sseq(n_samples)
for i in range(lo, hi):
if mirror_samples:
random.push_sseq(sseq[i//2])
if (i % 2) and (i-1 >= lo):
_samples.append(-_samples[-1])
else:
_samples.append(((i % 2)*2-1) *
met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype))
random.pop_sseq()
else:
random.push_sseq(sseq[i])
_samples.append(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype))
random.pop_sseq()
_samples = tuple(_samples)
if mirror_samples:
n_samples *= 2
self._samples = _samples
self._n_samples = n_samples
self._lin = Linearization.make_partial_var(mean, constants)
v, g = None, None
if len(self._samples) == 0: # hack if there are too many MPI tasks
tmp = self._hamiltonian(self._lin)
v = 0. * tmp.val.val
g = 0. * tmp.gradient
else:
for s in self._samples:
tmp = self._hamiltonian(self._lin+s)
if v is None:
v = tmp.val.val_rw()
g = tmp.gradient
else:
v += tmp.val.val
g = g + tmp.gradient
self._val = np_allreduce_sum(v)[()] / self._n_samples
self._grad = allreduce_sum_field(g) / self._n_samples
self._metric = None
self._sampdt = lh_sampling_dtype
def at(self, position):
return MetricGaussianKL_MPI(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, _samples=self._samples,
lh_sampling_dtype=self._sampdt)
@property
def value(self):
return self._val
@property
def gradient(self):
return self._grad
def _get_metric(self):
lin = self._lin.with_want_metric()
if self._metric is None:
if len(self._samples) == 0: # hack if there are too many MPI tasks
self._metric = self._hamiltonian(lin).metric.scale(0.)
else:
mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples)
self.unscaled_metric = utilities.my_sum(mymap)
self._metric = self.unscaled_metric.scale(1./self._n_samples)
def apply_metric(self, x):
self._get_metric()
return allreduce_sum_field(self._metric(x))
@property
def metric(self):
return KLMetric(self)
@property
def samples(self):
res = _comm.allgather(self._samples)
res = [item for sublist in res for item in sublist]
return res
def unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise NotImplementedError()
lin = self._lin.with_want_metric()
samp = full(self._hamiltonian.domain, 0.)
        sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._samples):
# FIXME: this is not yet correct. We need to use the _global_ sample
# index, not the local one!
random.push_sseq(sseq[i])
samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False, dtype=dtype)
random.pop_sseq()
return allreduce_sum_field(samp)
def metric_sample(self, from_inverse=False, dtype=np.float64):
return self.unscaled_metric_sample(from_inverse, dtype)/self._n_samples
from .minimization.metric_gaussian_kl_mpi import MetricGaussianKL_MPI
@@ -45,12 +45,13 @@ def test_kl(constants, point_estimates, mirror_samples):
point_estimates=point_estimates,
mirror_samples=mirror_samples,
napprox=0)
samp_full = kl.samples
klpure = ift.MetricGaussianKL(mean0,
h,
nsamps,
mirror_samples=mirror_samples,
len(samp_full),
mirror_samples=False,
napprox=0,
_samples=kl.samples)
_samples=samp_full)
# Test value
assert_allclose(kl.value, klpure.value)