# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Authors: Philipp Frank, Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import random, utilities
from ..domain_tuple import DomainTuple
from ..linearization import Linearization
from ..multi_field import MultiField
from ..operators.adder import Adder
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import GaussianEnergy, StandardHamiltonian
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingDtypeSetter
from ..operators.sandwich_operator import SandwichOperator
from ..operators.scaling_operator import ScalingOperator
from ..probing import approximation2endo
from ..sugar import makeOp
from ..utilities import myassert
from .descent_minimizers import ConjugateGradient, DescentMinimizer
from .energy import Energy
from .energy_adapter import EnergyAdapter
from .quadratic_energy import QuadraticEnergy


def _get_lo_hi(comm, n_samples):
    ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
    return utilities.shareRange(n_samples, ntask, rank)


def _modify_sample_domain(sample, domain):
    """Takes only keys from sample which are also in domain and inserts zeros
    for keys which are not in sample.domain."""
    from ..domain_tuple import DomainTuple
    from ..field import Field
    from ..multi_domain import MultiDomain
    from ..sugar import makeDomain
    domain = makeDomain(domain)
    if isinstance(domain, DomainTuple) and isinstance(sample, Field):
        if sample.domain is not domain:
            raise TypeError
        return sample
    elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField):
        if sample.domain is domain:
            return sample
        out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
        out = MultiField.from_dict(out, domain)
        return out
    raise TypeError


def _reduce_by_keys(field, operator, keys):
    """Partially insert a field into an operator

    If the domain of the operator is an instance of `DomainTuple`, `keys`
    must be empty and `field` and `operator` are returned unchanged.

    Parameters
    ----------
    field : Field or MultiField
        Potentially partially constant input field.
    operator : Operator
        Operator into which `field` is partially inserted.
    keys : list
        List of constant `MultiDomain` entries.

    Returns
    -------
    tuple
        The variable part of the field and the operator with the constant
        part inserted.
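
    For example, for a `MultiField` ``field`` with keys ``{'a', 'b'}`` and an
    operator ``op`` defined on the same `MultiDomain`, a hypothetical call::

        var_field, new_op = _reduce_by_keys(field, op, ['b'])

    leaves ``var_field`` with the key ``'a'`` only, while the ``'b'`` part of
    ``field`` is inserted into ``op`` as a constant.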
    """
    from ..sugar import is_fieldlike, is_operator
    myassert(is_fieldlike(field))
    myassert(is_operator(operator))
    if isinstance(field, MultiField):
        cst_field = field.extract_by_keys(keys)
        var_field = field.extract_by_keys(set(field.keys()) - set(keys))
        _, new_ham = operator.simplify_for_constant_input(cst_field)
        return var_field, new_ham
    myassert(len(keys) == 0)
    return field, operator


class _SelfAdjointOperatorWrapper(EndomorphicOperator):
    """Wrap a self-adjoint function into an `EndomorphicOperator`."""

    def __init__(self, domain, func):
        from ..sugar import makeDomain
        self._func = func
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = makeDomain(domain)

    def apply(self, x, mode):
        self._check_input(x, mode)
        return self._func(x)


class _SampledKLEnergy(Energy):
    """Base class for Energies representing a sampled Kullback-Leibler
    divergence for the variational approximation of a distribution with another
    distribution.

    Supports the samples to be distributed across MPI tasks."""
    def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
                 local_samples, nanisinf):
        super().__init__(mean)
        myassert(mean.domain is hamiltonian.domain)
        self._hamiltonian = hamiltonian
        self._n_samples = int(n_samples)
        self._mirror_samples = bool(mirror_samples)
        self._comm = comm
        self._local_samples = local_samples
        self._nanisinf = bool(nanisinf)

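        # Evaluate the Hamiltonian and its gradient at (mean + s) for every
        # local sample s (and at (mean - s) when mirroring), then average
        # over all samples and MPI tasks.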
        lin = Linearization.make_var(mean)
        v, g = [], []
        for s in self._local_samples:
            s = _modify_sample_domain(s, mean.domain)
            tmp = hamiltonian(lin+s)
            tv = tmp.val.val
            tg = tmp.gradient
            if mirror_samples:
                tmp = hamiltonian(lin-s)
                tv = tv + tmp.val.val
                tg = tg + tmp.gradient
            v.append(tv)
            g.append(tg)
        self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
        if np.isnan(self._val) and self._nanisinf:
            self._val = np.inf
        self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples

    @property
    def value(self):
        return self._val

    @property
    def gradient(self):
        return self._grad

    def at(self, position):
        return _SampledKLEnergy(
            position, self._hamiltonian, self._n_samples, self._mirror_samples,
            self._comm, self._local_samples, self._nanisinf)

    def apply_metric(self, x):
        lin = Linearization.make_var(self.position, want_metric=True)
        res = []
        for s in self._local_samples:
            s = _modify_sample_domain(s, self._hamiltonian.domain)
            tmp = self._hamiltonian(lin+s).metric(x)
            if self._mirror_samples:
                tmp = tmp + self._hamiltonian(lin-s).metric(x)
            res.append(tmp)
        return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples

    @property
    def n_eff_samples(self):
        if self._mirror_samples:
            return 2*self._n_samples
        return self._n_samples

    @property
    def metric(self):
        return _SelfAdjointOperatorWrapper(self.position.domain,
                                           self.apply_metric)

    @property
    def samples(self):
        ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
        if ntask == 1:
            for s in self._local_samples:
                yield s
                if self._mirror_samples:
                    yield -s
        else:
            rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
            lo, _ = _get_lo_hi(self._comm, self._n_samples)
            for itask, (l, h) in enumerate(rank_lo_hi):
                for i in range(l, h):
                    data = self._local_samples[i-lo] if rank == itask else None
                    s = self._comm.bcast(data, root=itask)
                    yield s
                    if self._mirror_samples:
                        yield -s


class _MetricGaussianSampler:
    def __init__(self, position, H, n_samples, mirror_samples, napprox=0):
        if not isinstance(H, StandardHamiltonian):
            raise NotImplementedError
        lin = Linearization.make_var(position.extract(H.domain), True)
        self._met = H(lin).metric
        if napprox >= 1:
            self._met._approximation = makeOp(approximation2endo(self._met, napprox))
        self._n = int(n_samples)

    def draw_samples(self, comm):
        local_samples = []
        utilities.check_MPI_synced_random_state(comm)
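        # Spawn one independent seed sequence per sample so that sample i is
        # reproducible regardless of how the samples are distributed over
        # the MPI tasks.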
        sseq = random.spawn_sseq(self._n)
        for i in range(*_get_lo_hi(comm, self._n)):
            with random.Context(sseq[i]):
                local_samples.append(self._met.draw_sample(from_inverse=True))
        return tuple(local_samples)


class _GeoMetricSampler:
    def __init__(self, position, H, minimizer, start_from_lin,
                 n_samples, mirror_samples, napprox=0, want_error=False):
        if not isinstance(H, StandardHamiltonian):
            raise NotImplementedError

        # Check domain dtype
        dts = H._prior._met._dtype
        if isinstance(H.domain, DomainTuple):
            real = np.issubdtype(dts, np.floating)
        else:
            real = all(np.issubdtype(dts[kk], np.floating) for kk in dts.keys())
        if not real:
            raise ValueError("_GeoMetricSampler only supports real valued latent DOFs.")
        # /Check domain dtype

        if isinstance(position, MultiField):
            self._position = position.extract(H.domain)
        else:
            self._position = position
        tr = H._lh.get_transformation()
        if tr is None:
            raise ValueError("_GeoMetricSampler only works for likelihoods")
        dtype, f_lh = tr
        scale = ScalingOperator(f_lh.target, 1.)
        if isinstance(dtype, dict):
            sampling = all(dtype[k] is not None for k in dtype.keys())
        else:
            sampling = dtype is not None
        scale = SamplingDtypeSetter(scale, dtype) if sampling else scale

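        # Linearize the coordinate transformation f_lh around the current
        # position; self._g computes the local expansion
        # g(x) = (x - position) + J^dagger (f_lh(x) - f_lh(position)),
        # with J the Jacobian of f_lh at the expansion point.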
        fl = f_lh(Linearization.make_var(self._position))
        self._g = (Adder(-self._position) + fl.jac.adjoint@Adder(-fl.val)@f_lh)
        self._likelihood = SandwichOperator.make(fl.jac, scale)
        self._prior = SamplingDtypeSetter(ScalingOperator(fl.domain, 1.), np.float64)
        self._met = self._likelihood + self._prior
        if napprox >= 1:
            self._approximation = makeOp(approximation2endo(self._met, napprox)).inverse
        else:
            self._approximation = None
        self._ic = H._ic_samp
        self._minimizer = minimizer
        self._start_from_lin = start_from_lin
        self._want_error = want_error

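        # Mirrored samples reuse the seed sequence of their positive
        # counterpart so that a sample and its mirror are drawn from the
        # same underlying random numbers.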
        sseq = random.spawn_sseq(n_samples)
        if mirror_samples:
            mysseq = []
            for seq in sseq:
                mysseq += [seq, seq]
        else:
            mysseq = sseq
        self._sseq = mysseq
        self._neg = (False, True)*n_samples if mirror_samples else (False, )*n_samples
        self._n_samples = n_samples
        self._mirror_samples = mirror_samples

    @property
    def n_eff_samples(self):
        return 2*self._n_samples if self._mirror_samples else self._n_samples

    @property
    def position(self):
        return self._position

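    # Draw the linear (MGVI-like) part of a sample: `s` is a standard-normal
    # prior sample, `nj` a likelihood-metric sample and `y = s + nj`. If
    # `start_from_lin` is set, `yi` approximately solves `met @ yi = y` via
    # conjugate gradient; otherwise `yi = s`.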
    def _draw_lin(self, neg):
        s = self._prior.draw_sample(from_inverse=True)
        s = -s if neg else s
        nj = self._likelihood.draw_sample()
        nj = -nj if neg else nj
        y = self._prior(s) + nj
        if self._start_from_lin:
            energy = QuadraticEnergy(s, self._met, y,
                                     _grad=self._likelihood(s) - nj)
            inverter = ConjugateGradient(self._ic)
            energy, convergence = inverter(energy,
                                           preconditioner=self._approximation)
            yi = energy.position
        else:
            yi = s
        return y, yi

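    # Refine the linear sample non-linearly: minimize |g(x) - y|^2 / 2 over
    # x, starting from position + yi, and return the difference between the
    # minimizer and the expansion point (optionally with an error estimate).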
    def _draw_nonlin(self, y, yi):
        en = EnergyAdapter(self._position+yi, GaussianEnergy(mean=y)@self._g,
                           nanisinf=True, want_metric=True)
        en, _ = self._minimizer(en)
        sam = en.position - self._position
        if self._want_error:
            er = y - self._g(sam)
            er = er.s_vdot(InversionEnabler(self._met, self._ic).inverse(er))
            return sam, er
        return sam

    def draw_samples(self, comm):
        local_samples = []
        prev = None
        utilities.check_MPI_synced_random_state(comm)
        utilities.check_MPI_equality(self._sseq, comm)
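        # One seed-sequence context per effective sample. When mirroring,
        # the negated linear sample of the positive draw is cached in `prev`
        # and reused as the starting point for its mirror.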
        for i in range(*_get_lo_hi(comm, self.n_eff_samples)):
            with random.Context(self._sseq[i]):
                neg = self._neg[i]
                if (prev is None) or not self._mirror_samples:
                    y, yi = self._draw_lin(neg)
                    if not neg:
                        prev = (-y, -yi)
                else:
                    (y, yi) = prev
                    prev = None
                local_samples.append(self._draw_nonlin(y, yi))
        return tuple(local_samples)


def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
                     point_estimates=[], napprox=0, comm=None, nanisinf=False):
    """Provides the sampled Kullback-Leibler divergence between a distribution
    and a Metric Gaussian.

    A Metric Gaussian is used to approximate another probability distribution.
    It is a Gaussian distribution that uses the Fisher information metric of
    the other distribution at the location of its mean to approximate the
    variance. In order to infer the mean, a stochastic estimate of the
    Kullback-Leibler divergence is minimized. This estimate is obtained by
    sampling the Metric Gaussian at the current mean. During minimization
    these samples are kept constant; only the mean is updated. Due to the
    typically nonlinear structure of the true distribution these samples have
    to be updated eventually by instantiating `MetricGaussianKL` again. For the
    true probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian probability distribution.
    hamiltonian : StandardHamiltonian
        Hamiltonian of the approximated probability distribution.
    n_samples : integer
        Number of samples used to stochastically estimate the KL.
    mirror_samples : boolean
        Whether the negatives of the drawn samples are also used, as they are
        equally legitimate samples. If true, the number of used samples
        doubles. Mirroring samples stabilizes the KL estimate as extreme
        sample variation is counterbalanced. Since it improves stability in
        many cases, it is recommended to set `mirror_samples` to `True`.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) still optimized for, yielding point estimates for them.
        Default is to draw samples for the complete domain.
    napprox : int
        Number of samples used to compute a preconditioner for sampling. No
        preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will always reside on the same task.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on
        these occasions but rather the minimizer is told that the position it
        has tried is not sensible.

    Note
    ----
    The two lists `constants` and `point_estimates` are independent of each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa.
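
    A minimal sketch of one MGVI iteration; ``lh``, ``ic_samp`` and
    ``minimizer`` are hypothetical stand-ins for a user-supplied likelihood
    energy, sampling iteration controller and `DescentMinimizer`::

        H = StandardHamiltonian(lh, ic_samp=ic_samp)
        kl = MetricGaussianKL(mean, H, n_samples=5, mirror_samples=True)
        kl, _ = minimizer(kl)
        mean = kl.position  # updated mean; re-instantiate to refresh samples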

    See also
    --------
    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
    """
    if not isinstance(hamiltonian, StandardHamiltonian):
        raise TypeError
    if hamiltonian.domain is not mean.domain:
        raise ValueError
    if not isinstance(n_samples, int):
        raise TypeError
    if not isinstance(mirror_samples, bool):
        raise TypeError
    if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
        raise RuntimeError(
            'Point estimates for whole domain. Use EnergyAdapter instead.')
    n_samples = int(n_samples)
    mirror_samples = bool(mirror_samples)

    _, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates)
    sampler = _MetricGaussianSampler(mean, ham_sampling, n_samples,
                                     mirror_samples, napprox)
    local_samples = sampler.draw_samples(comm)

    mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
    return _SampledKLEnergy(mean, hamiltonian, n_samples, mirror_samples, comm,
                            local_samples, nanisinf)


def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
                start_from_lin=True, constants=[], point_estimates=[],
                napprox=0, comm=None, nanisinf=True):
    """Provides the sampled Kullback-Leibler used in geometric Variational
    Inference (geoVI).

    In geoVI a probability distribution is approximated with a standard normal
    distribution in the canonical coordinate system of the Riemannian manifold
    associated with the metric of the other distribution. The coordinate
    transformation is approximated by expanding around a point. In order to
    infer the expansion point, a stochastic estimate of the Kullback-Leibler
    divergence is minimized. This estimate is obtained by sampling from the
    approximation using the current expansion point. During minimization these
    samples are kept constant; only the expansion point is updated. Due to the
    typically nonlinear structure of the true distribution these samples have
    to be updated eventually by instantiating `GeoMetricKL` again. For the true
    probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Parameters
    ----------
    mean : Field
        Expansion point of the coordinate transformation.
    hamiltonian : StandardHamiltonian
        Hamiltonian of the approximated probability distribution.
    n_samples : integer
        Number of samples used to stochastically estimate the KL.
    minimizer_samp : DescentMinimizer
        Minimizer used to draw samples.
    mirror_samples : boolean
        Whether the mirrored versions of the drawn samples are also used.
        If true, the number of used samples doubles.
        Mirroring samples stabilizes the KL estimate as extreme
        sample variation is counterbalanced.
    start_from_lin : boolean
        Whether the non-linear sampling should start from the inverse of the
        linearized transformation (i.e. the corresponding MGVI sample).
        If False, the minimization starts from the prior sample.
        Default is True.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) still optimized for, yielding point estimates for them.
        Default is to draw samples for the complete domain.
    napprox : int
        Number of samples used to compute a preconditioner for the linear
        sampling step.
        No preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will reside on the same task if possible.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on
        these occasions but rather the minimizer is told that the position it
        has tried is not sensible.

    Note
    ----
    The two lists `constants` and `point_estimates` are independent of each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa.

    Note
    ----
    As in MGVI, mirroring samples can help to stabilize the latent mean as it
    reduces sampling noise. But unlike in MGVI, a mirrored sample requires an
    additional solve of the non-linear transformation. Therefore, when using
    MPI, the mirrored samples also get distributed if enough tasks are
    available. If there are more total samples than tasks, the mirrored
    counterparts try to reside on the same task as their non-mirrored
    partners so that at least the starting position can be re-used.
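
    A minimal sketch of one geoVI iteration; ``lh``, ``ic_samp``,
    ``minimizer_samp`` and ``minimizer`` are hypothetical stand-ins for
    user-supplied objects::

        H = StandardHamiltonian(lh, ic_samp=ic_samp)
        kl = GeoMetricKL(mean, H, n_samples=5, minimizer_samp=minimizer_samp,
                         mirror_samples=True)
        kl, _ = minimizer(kl)
        mean = kl.position  # updated expansion point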

    See also
    --------
    `Geometric Variational Inference`, Philipp Frank, Reimar Leike,
    Torsten A. Enßlin, `<https://arxiv.org/abs/2105.10470>`_
    `<https://doi.org/10.3390/e23070853>`_
    """
    if not isinstance(hamiltonian, StandardHamiltonian):
        raise TypeError
    if hamiltonian.domain is not mean.domain:
        raise ValueError
    if not isinstance(n_samples, int):
        raise TypeError
    if not isinstance(mirror_samples, bool):
        raise TypeError
    if not isinstance(minimizer_samp, DescentMinimizer):
        raise TypeError
    if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
        s = 'Point estimates for whole domain. Use EnergyAdapter instead.'
        raise RuntimeError(s)

    n_samples = int(n_samples)
    mirror_samples = bool(mirror_samples)

    _, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates)
    sampler = _GeoMetricSampler(mean, ham_sampling, minimizer_samp,
                                start_from_lin, n_samples, mirror_samples,
                                napprox)
    local_samples = sampler.draw_samples(comm)
    mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
    return _SampledKLEnergy(mean, hamiltonian, sampler.n_eff_samples, False,
                            comm, local_samples, nanisinf)