hmc.py 17.3 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2019-2020 Max-Planck-Society

import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.harmonic_operators import FFTOperator
from ..operators.linear_operator import LinearOperator
from ..operators.scaling_operator import ScalingOperator
from ..probing import StatCalculator
from ..random import Context, current_rng, spawn_sseq
from ..sugar import makeField, makeOp
from ..utilities import allreduce_sum, get_MPI_params_from_comm, shareRange


def get_ESS(samples):
    """Estimate the effective sample size (ESS) of every parameter.

    The auto-correlation functions returned by `get_AFC` are accumulated
    along their last (lag) axis. For every parameter the cumulative sum is
    read off at the last lag before the auto-correlation first turns
    negative, and the ESS is computed as
    ``len(samples)/(1 + 2*integr_corr)``.

    Parameters
    -----------
    samples: sequence of MultiFields
        The samples to analyse. The domain of the result is taken from
        `samples[0]`.

    Returns
    -----------
    ESS: MultiField
        The effective sample size for all parameters in the samples.
    """
    result = {}
    autocorrelations = get_AFC(samples)
    for key in autocorrelations.keys():
        fld = autocorrelations[key].val
        addaxis = False
        if len(fld.shape) == 1:
            # Only the lag axis is present (the parameter has an empty
            # DomainTuple): add a dummy leading axis so that the indexing
            # below works uniformly. It is removed again via ESS[0].
            fld = fld.reshape((1,) + fld.shape)
            addaxis = True
        cum_field = np.cumsum(fld, axis=-1)
        # Index of the first negative auto-correlation along the lag axis.
        # np.argmax yields 0 when no entry is negative, in which case the
        # index used below becomes -1, i.e. the full cumulative sum.
        correlation_length = np.argmax(fld < 0, axis=-1)
        # Index tuple addressing every entry of the leading (parameter)
        # axes, extended by the per-entry lag index `correlation_length - 1`.
        indices = np.where(np.ones(cum_field[..., 0].shape))
        indices += (correlation_length.flatten() - 1,)
        # NOTE(review): subtracting 1 presumably removes the lag-0 term,
        # assuming the auto-correlation equals 1 there -- confirm against
        # the normalisation of the FFT used in get_AFC.
        integr_corr = cum_field[indices] - 1
        ESS = len(samples)/(1 + 2*integr_corr)
        # key[:-2] strips the '_t' suffix carried by the sample-field keys.
        if addaxis:
            result[key[:-2]] = Field(samples[0].domain[key[:-2]], ESS[0])
        else:
            result[key[:-2]] = Field(samples[0].domain[key[:-2]],
                                     ESS.reshape(correlation_length.shape))
    return MultiField.from_dict(result)


def get_AFC(samples):
    """Compute the auto-correlation function of every parameter in the samples.

    The standardized sample field of every parameter is Fourier transformed
    along its last (sample) axis, multiplied with its complex conjugate and
    transformed back. An `ACF_Selector` then keeps the first half of the
    lags.

    Parameters
    -----------
    samples: sequence of MultiFields
        The samples to analyse.

    Returns
    -----------
    result: dict
        Maps every key of the standardized sample field (parameter name
        plus the suffix '_t') to the real part of the selected
        auto-correlation function.
    """
    standardized = _standardized_sample_field(samples)
    n_samples = len(samples)
    acf = {}
    for key, signal in standardized.items():
        dom = signal.domain
        selector = ACF_Selector(dom, n_samples)
        # Transform only along the last sub-domain, i.e. the sample axis.
        fft = FFTOperator(dom, space=len(dom._dom) - 1)
        spectrum = fft(signal)
        power = spectrum.conjugate()*spectrum
        acf[key] = selector(fft.inverse(power)).real
    return acf


def _mean(fld, dom):
    """Average every entry of a sample field over its last axis.

    Parameters
    -----------
    fld: dict
        Sample field as built by `_sample_field`; the keys end in '_t'.
    dom: MultiDomain
        Domain of the result, addressed by the keys without that suffix.

    Returns
    -----------
    mean: MultiField
        The per-parameter mean over the last axis.
    """
    averaged = {}
    for key in fld.keys():
        name = key[:-2]  # drop the two-character '_t' suffix
        averaged[name] = makeField(dom[name], fld[key].val.mean(axis=-1))
    return MultiField.from_dict(averaged, dom)


def _var(fld, dom):
    """Compute the variance of every entry of a sample field over its last axis.

    Parameters
    -----------
    fld: dict
        Sample field as built by `_sample_field`; the keys end in '_t'.
    dom: MultiDomain
        Domain of the result, addressed by the keys without that suffix.

    Returns
    -----------
    var: MultiField
        The per-parameter variance over the last axis.
    """
    variances = {}
    for key in fld.keys():
        name = key[:-2]  # drop the two-character '_t' suffix
        variances[name] = makeField(dom[name], fld[key].val.var(axis=-1))
    return MultiField.from_dict(variances, dom)


def _standardized_sample_field(samples):
    """Build the sample field and standardize it along the sample axis.

    For every parameter the mean over the samples is subtracted and the
    result is divided by the square root of the variance over the samples.

    Parameters
    -----------
    samples: sequence of MultiFields
        The samples to stack and standardize.

    Returns
    -----------
    result: dict
        Maps ``key + '_t'`` to the standardized field of parameter `key`,
        living on the same domain as the corresponding sample field.
    """
    dom = samples[0].domain
    stacked = _sample_field(samples)
    mean = _mean(stacked, dom)
    var = _var(stacked, dom)
    standardized = {}
    for key in dom.keys():
        name_t = key + '_t'
        raw = stacked[name_t]
        # The new trailing axis broadcasts the statistics over the samples.
        centered = raw.val - mean[key].val[..., np.newaxis]
        std = (var[key].val**0.5)[..., np.newaxis]
        standardized[name_t] = makeField(raw.domain, centered/std)
    return standardized


def _sample_field(samples):
    """Stack the samples of every parameter along a new last axis.

    Parameters
    -----------
    samples: sequence of MultiFields
        The samples to stack. Keys and domains are taken from `samples[0]`.

    Returns
    -----------
    result: dict
        Maps ``key + '_t'`` to a field whose domain is the domain of
        parameter `key` extended by an `RGSpace` of length `len(samples)`;
        entry ``[..., i]`` holds the values of `samples[i][key]`.
    """
    time_domain = RGSpace(len(samples))
    n_steps = time_domain.shape[0]
    dom = samples[0].domain
    stacked = {}
    for key in dom.keys():
        space = dom[key]
        values = np.empty(space.shape + time_domain.shape)
        for i in range(n_steps):
            values[..., i] = samples[i][key].val
        stacked_dom = DomainTuple.make(space._dom + (time_domain,))
        stacked[key + '_t'] = makeField(stacked_dom, values)
    return stacked


class HMC_chain:
    """Class for individual chains to perform the Hamiltonian Monte Carlo sampling.

    Parameters
    -----------
    V: EnergyOperator
        The problem Hamiltonian, used as potential energy in the Hamiltonian
        Dynamics of HMC.
    position:  Fields/MultiFields
        The position the chains are initialized.
    M: DiagonalOperator
        The mass matrix for the momentum term in the Hamiltonian dynamics.
        If not set, a unit matrix is assumed. Default: None
    steplength: Float
        The length of the steps in the leapfrog integration. This should be
        tuned to achieve reasonable acceptance for the given problem.
        Default: 0.003
    steps: positive Integer
        The number of leapfrog integration steps for the next sample.
        Default: 10
    sseq: seed sequence
        The seed sequence of the chain. A child sequence is spawned from it
        for every momentum draw. It has to be provided.

    Raises
    -----------
    RuntimeError
        If `sseq` is None.
    """
    def __init__(self, V, position, M=None, steplength=0.003, steps=10, sseq=None):
        if sseq is None:
            raise RuntimeError('HMC_chain needs a seed sequence (sseq)')
        if M is None:
            M = ScalingOperator(position.domain, 1)
        self._position = position
        self.samples = []
        self._M = M
        self._V = V
        self._steplength = steplength
        self._steps = steps
        self._energies = []
        self._accepted = []
        self._current_acceptance = []
        self._sseq = sseq

    def sample(self, N):
        """ The method to draw a set of samples.

        Parameters
        -----------
        N: positive Integer
            The number of samples to be drawn.
        """
        for i in range(N):
            self._sample()
            logger.info(f'iteration: {i} acceptance: {self._current_acceptance[-1]} steplength: {self._steplength}')

    def warmup(self, N, preferred_acceptance=0.6, keep=False):
        """ Performing a warmup by tuning the steplength
         to achieve a certain acceptance rate and estimating the mass matrix.

        After the warmup samples are drawn, the mass matrix of the chain is
        replaced by the inverse of the diagonal operator built from the
        variance of all samples stored in the chain.

        Parameters
        -----------
        N: positive Integer
            The number of warmup samples to be drawn.
        preferred_acceptance: Float
            The acceptance rate according to which the stepsize is tuned.
            Default: 0.6
        keep: Boolean
            Whether to keep the drawn samples or discard them. Default: False
        """
        for i in range(N):
            self._sample()
            self._tune_parameters(preferred_acceptance)
            logger.info(f'WARMUP: {i} acceptance: {self._current_acceptance[-1]} steplength: {self._steplength}')
        sc = StatCalculator()
        for sample in self.samples:
            sc.add(sample)
        # The sampling code reads `self._M`; storing the estimate only under
        # `self.M` (as done previously) left it unused.
        self._M = makeOp(sc.var).inverse
        # Keep the public attribute that earlier versions exposed.
        self.M = self._M
        if not keep:
            self.samples = []

    def estimate_quantity(self, function):
        """ Estimates the result of a function over all samples of the chains.

        Parameters
        -----------
        function: Function
            The function to be evaluated and averaged over the samples.

        Returns
        -----------
        mean, var : Tuple
            The mean and variance over the samples.
        """
        sc = StatCalculator()
        for sample in self.samples:
            sc.add(function(sample))
        return sc.mean, sc.var

    def _sample(self):
        """Draws one sample according to the HMC algorithm."""
        # A fresh child seed sequence is used for every momentum draw.
        tmp = self._sseq.spawn(2)[1]
        with Context(tmp):
            momentum = self._M.draw_sample_with_dtype(dtype=np.float64)

        new_position, new_momentum = self._integrate(momentum)
        self._accepting(momentum, new_position, new_momentum)
        self._update_acceptance()

    def _integrate(self, momentum):
        """Performs the leapfrog integration of the equations of motion.

        Parameters
        -----------
        momentum: Field or Multifield
            The momentum vector in the Hamilton equations.

        Returns
        -----------
        position, momentum : Tuple
            Position and momentum after `steps` leapfrog steps, starting
            from the current position of the chain.
        """
        position = self._position
        for _ in range(self._steps):
            position, momentum = self._leapfrog(position, momentum)
        return position, momentum

    def _leapfrog(self, position, momentum):
        """Performs one leapfrog integration step.

        Parameters
        -----------
        position: Field or Multifield
            The position vector in the Hamilton equations.
        momentum: Field or Multifield
            The momentum vector in the Hamilton equations.

        Returns
        -----------
        position, momentum : Tuple
            The updated position and momentum.
        """
        # Half step in momentum, full step in position, half step in momentum.
        lin = Linearization.make_var(position)
        gradient = self._V(lin).gradient
        momentum = momentum - self._steplength/2.*gradient
        position = position + self._steplength*self._M.inverse(momentum)
        lin = Linearization.make_var(position)
        gradient = self._V(lin).gradient
        momentum = momentum - self._steplength/2.*gradient
        return position, momentum

    def _accepting(self, momentum, new_position, new_momentum):
        """ Decides whether to accept or decline a new position according to
        Metropolis-Hastings.

        The current position is then stored as new sample. A proposal whose
        acceptance rate evaluates to NaN is declined, so that every call
        records exactly one sample, one acceptance flag and one energy.

        Parameters
        -----------
        momentum: Field or Multifield
            The old momentum vector in the Hamilton equations.
        new_position: Field or Multifield
            The new position vector after evolving the equations of motion.
        new_momentum: Field or Multifield
            The new momentum vector after evolving the equations of motion.
        """
        energy = self._V(self._position).val + (
            0.5*momentum.vdot(self._M.inverse(momentum))).val
        new_energy = self._V(new_position).val + (
            0.5*new_momentum.vdot(self._M.inverse(new_momentum))).val
        if new_energy < energy:
            accept = 1
        else:
            rate = np.exp(energy - new_energy)
            if np.isnan(rate):
                # Returning here without bookkeeping would shorten the chain
                # and may leave the acceptance history empty.
                accept = 0
            else:
                accept = current_rng().binomial(1, rate)
        if accept:
            self._position = new_position
        self._accepted.append(accept)
        self.samples.append(self._position)
        self._energies.append(energy)

    def _update_acceptance(self):
        """Calculates the current acceptance rate based on the last ten samples."""
        recent = np.array(self._accepted[-10:])
        self._current_acceptance.append(np.mean(recent))

    def _tune_parameters(self, preferred_acceptance):
        """Increases or decreases the steplength in the leapfrog integration
        based on the current acceptance rate to aim for the preferred rate.

        Parameters
        -----------
        preferred_acceptance: Float
            The preferred acceptance rate.
        """
        if self._current_acceptance[-1] < preferred_acceptance:
            self._steplength *= 0.99
        else:
            self._steplength *= 1.01

    @property
    def ESS(self):
        """The effective sample size over all samples of the chain.

        Returns
        -----------
        ESS: MultiField
            The effective sample size of all model parameters of the chain.
        """
        return get_ESS(self.samples)

    def mean(self):
        """The mean over all samples of the chain.

        Returns
        -----------
        mean: Field or MultiField
            The mean over all samples of the chain.
        """
        # `_sample_field` is a module-level helper, not a method of the chain.
        return self._mean(_sample_field(self.samples))

    def _mean(self, fld):
        """Mean of the sample field `fld` over its last axis, on the domain
        of the current position. Delegates to the module-level `_mean`."""
        return _mean(fld, self._position.domain)

    def _var(self, fld):
        """Variance of the sample field `fld` over its last axis, on the
        domain of the current position. Delegates to the module-level `_var`."""
        return _var(fld, self._position.domain)


class HMC_Sampler:
    """The sampler class, managing chains and the computations of diagnostics.

    Parameters
    -----------
    V: EnergyOperator
        The problem Hamiltonian, used as potential energy in the Hamiltonian
        Dynamics of HMC.
    initial_position: List of Fields/MultiFields
        The position the chains are initialized. Entry `i` is the starting
        position of chain `i`.
    chains: positive Integer
        The number of chains. Default: 1
    M: DiagonalOperator
        The mass matrix for the momentum term in the Hamiltonian dynamics.
        If not set, a unit matrix is assumed. Default: None
    steplength: Float
        The length of the steps in the leapfrog integration. This should be
        tuned to achieve reasonable acceptance for the given problem.
        Default: 0.003
    steps: positive Integer
        The number of leapfrog integration steps for the next sample.
        Default: 10
    comm: MPI communicator or None
        The communicator over which the chains are distributed; every task
        only holds the chains assigned to it by `shareRange`. Default: None
    """
    def __init__(self, V, initial_position, chains=1, M=None, steplength=0.003, steps=10, comm=None):
        self._M = M
        self._V = V
        self._dom = initial_position[0].domain  # FIXME temporary!
        self._steplength = steplength
        self._steps = steps
        self._N_chains = chains
        # One seed sequence per chain, indexed by the global chain number.
        sseq = spawn_sseq(self._N_chains)
        self._comm = comm
        ntask, rank, _ = get_MPI_params_from_comm(self._comm)
        lo, hi = shareRange(self._N_chains, ntask, rank)
        self._local_chains = [
            HMC_chain(self._V, initial_position[i], self._M,
                      self._steplength, self._steps, sseq[i])
            for i in range(lo, hi)
        ]

    def sample(self, N):
        """The method to draw a set of samples in every chain.

        Parameters
        -----------
        N: positive Integer
            The number of samples to be drawn in every chain.
        """
        for chain in self._local_chains:
            chain.sample(N)

    def warmup(self, N, preferred_acceptance=0.6, keep=False):
        """Performing a warmup by tuning the steplength to achieve a certain
        acceptance rate and estimating the mass matrix.

        Parameters
        -----------
        N: positive Integer
            The number of warmup samples to be drawn in every chain.
        preferred_acceptance: Float
            The acceptance rate according to which the stepsize is tuned.
            Default: 0.6
        keep: Boolean
            Whether to keep the drawn samples or discard them. Default: False
        """
        for chain in self._local_chains:
            chain.warmup(N, preferred_acceptance, keep)

    def estimate_quantity(self, function):
        """Estimates the result of a function over all samples and chains.

        The per-chain means and variances are summed over all chains and
        divided by the number of chains.

        Parameters
        -----------
        function: Function
            The function to be evaluated and averaged over the samples.

        Returns
        -----------
        mean, var : Tuple
            The mean and variance over the samples.

        """
        locmeanvar = [
            chain.estimate_quantity(function) for chain in self._local_chains
        ]
        locmean = [x[0] for x in locmeanvar]
        locvar = [x[1] for x in locmeanvar]
        mean = allreduce_sum(locmean, self._comm)
        var = allreduce_sum(locvar, self._comm)
        return mean/self._N_chains, var/self._N_chains

    @property
    def ESS(self):
        """The effective sample size over all samples and chains.

        Returns
        -----------
        ESS: MultiField
            The effective sample size of all model parameters, summed over
            the chains.
        """
        return allreduce_sum([chain.ESS for chain in self._local_chains], self._comm)

    @property
    def R_hat(self):
        """The Gelman-Rubin test statistic R_hat.

        It measures how well the samples of different chains agree to determine
        convergence. Ideally this quantity is close to unity.

        The number of samples per chain is read from the first chain of the
        master task. The between-chain term divides by the number of chains
        minus one, so at least two chains are needed for a meaningful result.

        Returns
        -----------
        R_hat: Field or MultiField
            The value of R_hat for all model parameters.
        """
        ntask, rank, master = get_MPI_params_from_comm(self._comm)
        N = len(self._local_chains[0].samples) if master else None
        if ntask > 1:
            N = self._comm.bcast(N, root=0)
        M = self._N_chains
        dom = self._dom
        locfld = [_sample_field(chain.samples) for chain in self._local_chains]
        locmean = [_mean(fld, dom) for fld in locfld]
        locW = [_var(fld, dom) for fld in locfld]
        mean_mean = allreduce_sum(locmean, self._comm)/M
        # W: within-chain variance averaged over the chains.
        W = allreduce_sum(locW, self._comm)/M
        # B: between-chain variance; the per-chain means computed above are
        # reused instead of being recomputed from the sample fields.
        locB = [(mean_mean - chain_mean)**2 for chain_mean in locmean]
        B = allreduce_sum(locB, self._comm)*N/(M - 1)
        var_theta = (1 - 1/N)*W + (M + 1)/(N*M)*B
        return (var_theta/W).sqrt()


class ACF_Selector(LinearOperator):
    """Linear operator keeping the first half of the last axis of a field.

    Parameters
    -----------
    domain: Domain, tuple of Domains or DomainTuple
        The domain of the input; its last sub-domain is the axis that gets
        truncated.
    N_samps: Integer
        The number of samples; the first `N_samps//2` entries along the last
        axis are kept.

    Notes
    -----------
    The target consists of all but the last sub-domain of `domain`, followed
    by an `UnstructuredDomain` of size `N_samps//2`. Only the `TIMES` mode
    is declared as capability.
    """
    def __init__(self, domain, N_samps):
        self._N_samps = N_samps
        self._domain = DomainTuple.make(domain)
        leading = self.domain._dom[:-1]
        half_axis = UnstructuredDomain(self._N_samps//2)
        self._target = DomainTuple.make(leading + (half_axis,))
        self._capability = self.TIMES

    def apply(self, x, mode):
        """Return the first `N_samps//2` entries of `x` along its last axis
        as a field on the target domain."""
        self._check_input(x, mode)
        n_keep = self._N_samps//2
        return makeField(self._target, x.val[..., :n_keep])