Users do not instantiate `MetricGaussianKL` by its constructor anymore.
ift.StatCalculator() scA2 = ift.StatCalculator() diff --git a/demos/mgvi_visualized.py b/demos/mgvi_visualized.py index a64114ffdf33518f771ae70257ea2e17f5850a14..69e0373ede3aed5895ee51a87132e9fb9536dca0 100644 --- a/demos/mgvi_visualized.py +++ b/demos/mgvi_visualized.py @@ -34,6 +34,7 @@ from matplotlib.colors import LogNorm import nifty7 as ift + def main(): dom = ift.UnstructuredDomain(1) scale = 10 @@ -90,7 +91,7 @@ def main(): plt.figure(figsize=[12, 8]) for ii in range(15): if ii % 3 == 0: - mgkl = ift.MetricGaussianKL(pos, ham, 40) + mgkl = ift.MetricGaussianKL.make(pos, ham, 40) plt.cla() plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3, diff --git a/src/minimization/metric_gaussian_kl.py b/src/minimization/metric_gaussian_kl.py index 83f769d579d9a8232610f148d35e8f6e4ae26520..77bba50ab148db75f536ebaff6cc97c2db9780e6 100644 --- a/src/minimization/metric_gaussian_kl.py +++ b/src/minimization/metric_gaussian_kl.py @@ -24,7 +24,7 @@ from ..multi_field import MultiField from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.energy_operators import StandardHamiltonian from ..probing import approximation2endo -from ..sugar import makeDomain, makeOp +from ..sugar import makeOp from .energy import Energy @@ -42,6 +42,11 @@ class _KLMetric(EndomorphicOperator): return self._KL._metric_sample(from_inverse) +def _get_lo_hi(comm, n_samples): + ntask, rank, _ = utilities.get_MPI_params_from_comm(comm) + return utilities.shareRange(n_samples, ntask, rank) + + class MetricGaussianKL(Energy): """Provides the sampled Kullback-Leibler divergence between a distribution and a Metric Gaussian. @@ -58,58 +63,91 @@ class MetricGaussianKL(Energy): true probability distribution the standard parametrization is assumed. The samples of this class can be distributed among MPI tasks. - Parameters - ---------- - mean : Field - Mean of the Gaussian probability distribution. 
`MetricGaussianKL` objects should never be created using the constructor, but rather
Enßlin, `<https://arxiv.org/abs/1901.11033>`_ """ - - def __init__(self, mean, hamiltonian, n_samples, constants=[], - point_estimates=[], mirror_samples=False, - napprox=0, comm=None, _local_samples=None, - nanisinf=False): + def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm, + local_samples, nanisinf, _callingfrommake=False): + if not _callingfrommake: + raise NotImplementedError super(MetricGaussianKL, self).__init__(mean) + self._hamiltonian = hamiltonian + self._n_samples = int(n_samples) + self._mirror_samples = bool(mirror_samples) + self._comm = comm + self._local_samples = local_samples + self._nanisinf = bool(nanisinf) + + lin = Linearization.make_var(mean) + v, g = [], [] + for s in self._local_samples: + tmp = hamiltonian(lin+s) + tv = tmp.val.val + tg = tmp.gradient + if mirror_samples: + tmp = hamiltonian(lin-s) + tv = tv + tmp.val.val + tg = tg + tmp.gradient + v.append(tv) + g.append(tg) + self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples + if np.isnan(self._val) and self._nanisinf: + self._val = np.inf + self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples + + @staticmethod + def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[], + mirror_samples=False, napprox=0, comm=None, nanisinf=False): + """Return instance of :class:`MetricGaussianKL`. + + Parameters + ---------- + mean : Field + Mean of the Gaussian probability distribution. + hamiltonian : StandardHamiltonian + Hamiltonian of the approximated probability distribution. + n_samples : integer + Number of samples used to stochastically estimate the KL. + constants : list + List of parameter keys that are kept constant during optimization. + Default is no constants. + point_estimates : list + List of parameter keys for which no samples are drawn, but that are + (possibly) optimized for, corresponding to point estimates of these. + Default is to draw samples for the complete domain. 
these occasions but rather
Use EnergyAdapter instead.') + n_samples = int(n_samples) + mirror_samples = bool(mirror_samples) - self._hamiltonian = hamiltonian - if len(constants) > 0: - dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants} - dom = makeDomain(dom) - cstpos = mean.extract(dom) - _, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos) - - self._n_samples = int(n_samples) - self._comm = comm - ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm) - self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank) - - self._mirror_samples = bool(mirror_samples) - self._n_eff_samples = self._n_samples - if self._mirror_samples: - self._n_eff_samples *= 2 - - if _local_samples is None: - if len(point_estimates) > 0: - dom = {kk: vv for kk, vv in mean.domain.items() - if kk in point_estimates} - dom = makeDomain(dom) - cstpos = mean.extract(dom) - _, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos) - met = hamiltonian(Linearization.make_var(mean, True)).metric - if napprox >= 1: - met._approximation = makeOp(approximation2endo(met, napprox)) - _local_samples = [] - sseq = random.spawn_sseq(self._n_samples) - for i in range(self._lo, self._hi): - with random.Context(sseq[i]): - _local_samples.append(met.draw_sample(from_inverse=True)) - _local_samples = tuple(_local_samples) + if isinstance(mean, MultiField): + cstpos = mean.extract_by_keys(point_estimates) + _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos) else: - if len(_local_samples) != self._hi-self._lo: - raise ValueError("# of samples mismatch") - self._local_samples = _local_samples - self._lin = Linearization.make_var(mean) - v, g = [], [] - for s in self._local_samples: - tmp = self._hamiltonian(self._lin+s) - tv = tmp.val.val - tg = tmp.gradient - if self._mirror_samples: - tmp = self._hamiltonian(self._lin-s) - tv = tv + tmp.val.val - tg = tg + tmp.gradient - v.append(tv) - g.append(tg) - self._val = utilities.allreduce_sum(v, 
self._comm)[()]/self._n_eff_samples - if np.isnan(self._val) and self._mitigate_nans: - self._val = np.inf - self._grad = utilities.allreduce_sum(g, self._comm)/self._n_eff_samples + ham_sampling = hamiltonian + met = ham_sampling(Linearization.make_var(mean, True)).metric + if napprox >= 1: + met._approximation = makeOp(approximation2endo(met, napprox)) + local_samples = [] + sseq = random.spawn_sseq(n_samples) + for i in range(*_get_lo_hi(comm, n_samples)): + with random.Context(sseq[i]): + local_samples.append(met.draw_sample(from_inverse=True)) + local_samples = tuple(local_samples) + + if isinstance(mean, MultiField): + _, hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(constants)) + return MetricGaussianKL( + mean, hamiltonian, n_samples, mirror_samples, comm, local_samples, + nanisinf, _callingfrommake=True) def at(self, position): return MetricGaussianKL( - position, self._hamiltonian, self._n_samples, - mirror_samples=self._mirror_samples, comm=self._comm, - _local_samples=self._local_samples, nanisinf=self._mitigate_nans) + position, self._hamiltonian, self._n_samples, self._mirror_samples, + self._comm, self._local_samples, self._nanisinf, True) @property def value(self): @@ -193,14 +198,20 @@ class MetricGaussianKL(Energy): return self._grad def apply_metric(self, x): - lin = self._lin.with_want_metric() + lin = Linearization.make_var(self.position, want_metric=True) res = [] for s in self._local_samples: tmp = self._hamiltonian(lin+s).metric(x) if self._mirror_samples: tmp = tmp + self._hamiltonian(lin-s).metric(x) res.append(tmp) - return utilities.allreduce_sum(res, self._comm)/self._n_eff_samples + return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples + + @property + def n_eff_samples(self): + if self._mirror_samples: + return 2*self._n_samples + return self._n_samples @property def metric(self): @@ -216,9 +227,10 @@ class MetricGaussianKL(Energy): yield -s else: rank_lo_hi = 
[utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)] + lo, _ = _get_lo_hi(self._comm, self._n_samples) for itask, (l, h) in enumerate(rank_lo_hi): for i in range(l, h): - data = self._local_samples[i-self._lo] if rank == itask else None + data = self._local_samples[i-lo] if rank == itask else None s = self._comm.bcast(data, root=itask) yield s if self._mirror_samples: @@ -231,7 +243,7 @@ class MetricGaussianKL(Energy): ' not take point_estimates into accout. Make sure that this ' 'is your intended use.') logger.warning(s) - lin = self._lin.with_want_metric() + lin = Linearization.make_var(self.position, True) samp = [] sseq = random.spawn_sseq(self._n_samples) for i, v in enumerate(self._local_samples): @@ -240,4 +252,4 @@ class MetricGaussianKL(Energy): if self._mirror_samples: tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False) samp.append(tmp) - return utilities.allreduce_sum(samp, self._comm)/self._n_eff_samples + return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples diff --git a/src/multi_field.py b/src/multi_field.py index 4208ab43e8a4570855290c4e605ba17271c02f47..bf352d33b260f05b22af583b1b69406f7f499fb3 100644 --- a/src/multi_field.py +++ b/src/multi_field.py @@ -248,6 +248,10 @@ class MultiField(Operator): return MultiField(subset, tuple(self[key] for key in subset.keys())) + def extract_by_keys(self, keys): + dom = MultiDomain.make({kk: vv for kk, vv in self.domain.items() if kk in keys}) + return self.extract(dom) + def extract_part(self, subset): if subset is self._domain: return self diff --git a/src/operators/operator.py b/src/operators/operator.py index 813858ce9ea31d6bb2802c88273ef2cea11006d1..d56472dafb97e6f132a5e1960045b4a3cc7b7c93 100644 --- a/src/operators/operator.py +++ b/src/operators/operator.py @@ -275,22 +275,25 @@ class Operator(metaclass=NiftyMeta): from .simplify_for_const import ConstantEnergyOperator, ConstantOperator if c_inp is None: return None, self + dom = c_inp.domain + if 
isinstance(dom, MultiDomain) and len(dom) == 0: + return None, self # Convention: If c_inp is MultiField, it needs to be defined on a # subdomain of self._domain if isinstance(self.domain, MultiDomain): - assert isinstance(c_inp.domain, MultiDomain) + assert isinstance(dom, MultiDomain) if set(c_inp.keys()) > set(self.domain.keys()): raise ValueError - if c_inp.domain is self.domain: + if dom is self.domain: if isinstance(self, EnergyOperator): op = ConstantEnergyOperator(self.domain, self(c_inp)) else: op = ConstantOperator(self.domain, self(c_inp)) op = ConstantOperator(self.domain, self(c_inp)) return op(c_inp), op - if not isinstance(c_inp.domain, MultiDomain): + if not isinstance(dom, MultiDomain): raise RuntimeError return self._simplify_for_constant_input_nontrivial(c_inp) diff --git a/src/sugar.py b/src/sugar.py index 683b970af363b757bbceadd8ec4c5eb140705b2a..bac7704c7d540211066851821fb63abe36cb1471 100644 --- a/src/sugar.py +++ b/src/sugar.py @@ -520,7 +520,7 @@ def calculate_position(operator, output): minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos')) for ii in range(3): logger.info(f'Start iteration {ii+1}/3') - kl = MetricGaussianKL(pos, H, 3, mirror_samples=True) + kl = MetricGaussianKL.make(pos, H, 3, mirror_samples=True) kl, _ = minimizer(kl) pos = kl.position return pos diff --git a/test/test_kl.py b/test/test_kl.py index cee37c33a7b803b4ce40db14f7ae3d4ac557665b..b8f07a3f54101299adda6d2b8faf123f74a6dfa3 100644 --- a/test/test_kl.py +++ b/test/test_kl.py @@ -52,9 +52,9 @@ def test_kl(constants, point_estimates, mirror_samples, mf): 'hamiltonian': h} if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()): with assert_raises(RuntimeError): - ift.MetricGaussianKL(**args) + ift.MetricGaussianKL.make(**args) return - kl = ift.MetricGaussianKL(**args) + kl = ift.MetricGaussianKL.make(**args) assert_(len(ic.history) > 0) assert_(len(ic.history) == len(ic.history.time_stamps)) assert_(len(ic.history) 
== len(ic.history.energy_values)) @@ -64,13 +64,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf): assert_(len(ic.history) == len(ic.history.energy_values)) locsamp = kl._local_samples - klpure = ift.MetricGaussianKL(mean0, - h, - nsamps, - mirror_samples=mirror_samples, - constants=constants, - point_estimates=point_estimates, - _local_samples=locsamp) + if isinstance(mean0, ift.MultiField): + _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants)) + else: + tmph = h + klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None, locsamp, False, True) # Test number of samples expected_nsamps = 2*nsamps if mirror_samples else nsamps diff --git a/test/test_mpi/test_kl.py b/test/test_mpi/test_kl.py index 7820498d1866e80b0333bb52b2bdc69fce3ef03d..f0eef5655cd42106c05c5ca79da554da35a2c58f 100644 --- a/test/test_mpi/test_kl.py +++ b/test/test_mpi/test_kl.py @@ -60,19 +60,27 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf): 'hamiltonian': h} if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()): with assert_raises(RuntimeError): - ift.MetricGaussianKL(**args, comm=comm) + ift.MetricGaussianKL.make(**args, comm=comm) return if mode == 0: - kl0 = ift.MetricGaussianKL(**args, comm=comm) + kl0 = ift.MetricGaussianKL.make(**args, comm=comm) locsamp = kl0._local_samples - kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp) + if isinstance(mean0, ift.MultiField): + _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants)) + else: + tmph = h + kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True) elif mode == 1: - kl0 = ift.MetricGaussianKL(**args) + kl0 = ift.MetricGaussianKL.make(**args) samples = kl0._local_samples ii = len(samples)//2 slc = slice(None, ii) if rank == 0 else slice(ii, None) locsamp = samples[slc] - kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp) + if isinstance(mean0, 
ift.MultiField): + _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants)) + else: + tmph = h + kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True) # Test number of samples expected_nsamps = 2*nsamps if mirror_samples else nsamps diff --git a/test/test_sugar.py b/test/test_sugar.py index 73584cc6e8dcd6cf3b2342c976bad9cb65769375..61b59a47e2cca99b0cf3abd80461bdc278a24276 100644 --- a/test/test_sugar.py +++ b/test/test_sugar.py @@ -16,11 +16,15 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np +import pytest from numpy.testing import assert_equal import nifty7 as ift + from .common import setup_function, teardown_function +pmp = pytest.mark.parametrize + def test_get_signal_variance(): space = ift.RGSpace(3) @@ -45,15 +49,13 @@ def test_exec_time(): lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op1 ic = ift.GradientNormController(iteration_limit=2) ham = ift.StandardHamiltonian(lh, ic_samp=ic) - kl = ift.MetricGaussianKL(ift.full(ham.domain, 0.), ham, 1) + kl = ift.MetricGaussianKL.make(ift.full(ham.domain, 0.), ham, 1) ops = [op, op1, lh, ham, kl] for oo in ops: for wm in [True, False]: ift.exec_time(oo, wm) -import pytest -pmp = pytest.mark.parametrize @pmp('mf', [False, True]) @pmp('cplx', [False, True]) def test_calc_pos(mf, cplx):