from __future__ import absolute_import, division, print_function

from ..compat import *
from .energy import Energy
from ..linearization import Linearization
from ..operators.scaling_operator import ScalingOperator
from ..operators.block_diagonal_operator import BlockDiagonalOperator
from .. import utilities
from ..field import Field
from ..multi_field import MultiField

from mpi4py import MPI
import numpy as np
# MPI context shared by every helper in this module.
_comm = MPI.COMM_WORLD
# Number of MPI tasks, and the id of the task running this process.
ntask, rank = _comm.Get_size(), _comm.Get_rank()
# True only on task 0.
master = rank == 0


def _shareRange(nwork, nshares, myshare):
    nbase = nwork//nshares
    additional = nwork % nshares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + int(myshare < additional)
    return lo, hi


def np_allreduce_sum(arr):
    """Return the element-wise sum of ``arr`` over all tasks of ``_comm``.

    ``arr`` itself is left untouched. The reduced result is written into a
    freshly allocated array with the same shape and dtype.
    """
    total = np.empty_like(arr)
    _comm.Allreduce(arr, total, MPI.SUM)
    return total


def allreduce_sum_field(fld):
    """Sum a Field or MultiField element-wise across all MPI tasks.

    Parameters
    ----------
    fld : Field or MultiField
        This task's contribution. Anything that is not a ``Field`` is
        treated as a ``MultiField``: its ``values()`` are reduced one by one.

    Returns
    -------
    Field or MultiField
        Same kind and domain as ``fld``, holding the sum over all tasks.
    """
    def _reduce(f):
        # Reduce the task-local data and rewrap it on the original domain.
        return Field.from_local_data(f.domain, np_allreduce_sum(f.local_data))

    if isinstance(fld, Field):
        return _reduce(fld)
    return MultiField(fld.domain, tuple(_reduce(f) for f in fld.values()))


class KL_Energy(Energy):
    """Sampled estimate of a KL energy, with samples distributed over MPI.

    The value and gradient are the mean of ``h`` evaluated at
    ``position + s`` over ``nsamp`` samples ``s``. The samples are drawn
    from the inverse metric of ``h`` at the position where the energy was
    first constructed, and are reused by :meth:`at`. Each MPI task holds only
    its contiguous share of the samples (see ``_shareRange``). Value and
    gradient are summed across tasks with an MPI Allreduce.

    Parameters
    ----------
    position : Field or MultiField
        Point at which the energy is evaluated.
    h : callable
        Applied to Linearizations. Its result must provide ``val``,
        ``gradient`` and ``metric``.
    nsamp : int
        Total number of samples over all MPI tasks. Must be at least the
        number of tasks.
    constants : list, optional
        Keys of ``position.domain`` to hold fixed: their Jacobian block is a
        ``ScalingOperator`` with factor 0. When non-empty, ``position`` is
        iterated via ``position.domain.items()``, so it presumably must be a
        MultiField. Default: no constants.
    _samples : iterable, optional
        Private. This task's samples. :meth:`at` passes them on so that the
        same samples are reused at a new position. If None, they are drawn.
    want_metric : bool, optional
        Forwarded to the Linearization so that ``h`` also builds a metric.
    """

    def __init__(self,
                 position,
                 h,
                 nsamp,
                 constants=None,
                 _samples=None,
                 want_metric=False):
        super(KL_Energy, self).__init__(position)
        if constants is None:
            # Fresh list per call instead of a shared mutable default.
            constants = []
        self._h = h
        self._nsamp = nsamp
        self._constants = constants
        self._want_metric = want_metric
        if nsamp < ntask:
            # FIXME We need a better solution here. It is probably not good if
            # the script just dies. Can we proceed anyways?
            print('Number of samples: {}, number of MPI tasks: {}'.format(
                nsamp, ntask))
            raise RuntimeError('Cannot use more tasks than samples.')
        if _samples is None:
            lo, hi = _shareRange(nsamp, ntask, rank)
            met = h(Linearization.make_var(position, True)).metric
            _samples = []
            for i in range(lo, hi):
                # Seeding with the global sample index presumably makes sample
                # i independent of the number of MPI tasks (assuming
                # draw_sample uses NumPy's global RNG).
                # NOTE(review): this reseeds NumPy's global RNG and leaves it
                # in that state for the caller.
                np.random.seed(i)
                _samples.append(met.draw_sample(from_inverse=True))
        self._samples = tuple(_samples)

        if len(constants) == 0:
            lin = Linearization.make_var(position, want_metric)
        else:
            # Zero the Jacobian blocks of constant keys, identity elsewhere.
            ops = [
                ScalingOperator(0. if key in constants else 1., dom)
                for key, dom in position.domain.items()
            ]
            bdop = BlockDiagonalOperator(position.domain, tuple(ops))
            lin = Linearization(position, bdop, want_metric=want_metric)

        # Each task sums its own samples but normalizes by the *total* nsamp,
        # so the cross-task sums below yield the global mean.
        local_mean = utilities.my_sum(
            self._h(lin + s) for s in self._samples) * (1. / self._nsamp)
        self._val = np_allreduce_sum(local_mean.val.local_data)[()]
        self._grad = allreduce_sum_field(local_mean.gradient)
        # Only this task's partial metric; apply_metric reduces its output.
        self._metric = local_mean.metric

    def at(self, position):
        """Return a KL_Energy at ``position`` that reuses this task's samples."""
        return KL_Energy(position, self._h, self._nsamp, self._constants,
                         self._samples, self._want_metric)

    @property
    def value(self):
        """The sample-averaged energy, summed over all MPI tasks."""
        return self._val

    @property
    def gradient(self):
        """The sample-averaged gradient, summed over all MPI tasks."""
        return self._grad

    def apply_metric(self, x):
        """Apply each task's partial metric to ``x`` and sum over all tasks.

        Presumably requires ``want_metric=True``; otherwise the stored
        metric may be None. TODO confirm.
        """
        return allreduce_sum_field(self._metric(x))

    @property
    def metric(self):
        """This task's partial metric; only meaningful with a single task.

        Raises
        ------
        ValueError
            If more than one MPI task is running, because the stored metric
            holds only the local share of the samples.
        """
        if ntask > 1:
            raise ValueError("not supported when MPI is active")
        return self._metric

    @property
    def samples(self):
        """All samples from all MPI tasks, concatenated in task-rank order."""
        per_task = _comm.allgather(self._samples)
        return [item for sublist in per_task for item in sublist]