kl_energy.py 2.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
from __future__ import absolute_import, division, print_function

from ..compat import *
from .energy import Energy
from ..linearization import Linearization
from .. import utilities


class KL_Energy(Energy):
    """Energy whose value, gradient and metric are averaged over samples.

    Each quantity is the mean of ``h`` evaluated at ``position + s`` over
    a fixed set of samples ``s``. Unless ``_samples`` is supplied, the
    samples are drawn from the inverse of the metric of ``h`` at
    ``position``. The same samples are reused when the energy is moved to
    a new position via `at`.

    Parameters
    ----------
    position : Field-like
        Point at which the energy is evaluated. Its ``domain`` must be the
        very same object as ``h.domain``.
    h : operator
        Operator evaluated at ``position + s`` for every sample ``s``.
        Presumably a Hamiltonian; TODO confirm against callers.
    nsamp : int
        Number of samples to draw when ``_samples`` is None.
    constants : iterable, optional
        Passed to ``Linearization.make_partial_var`` when linearizing
        around ``position`` for value, gradient and metric. Defaults to an
        empty list.
    constants_samples : iterable, optional
        Passed to ``Linearization.make_partial_var`` when computing the
        metric that the samples are drawn from. Defaults to ``constants``.
    gen_mirrored_samples : bool, optional
        If True, the negation of every drawn sample is appended as well,
        doubling the sample count.
    _samples : sequence, optional
        Pre-computed samples. If given, no new samples are drawn. Used
        internally by `at`.

    Raises
    ------
    TypeError
        If ``h.domain`` is not the same object as ``position.domain``.
    ValueError
        If the resulting set of samples is empty.
    """

    def __init__(self, position, h, nsamp, constants=None,
                 constants_samples=None, gen_mirrored_samples=False,
                 _samples=None):
        super(KL_Energy, self).__init__(position)

        # Identity (not equality) check, as in the original implementation.
        if h.domain is not position.domain:
            raise TypeError("h.domain must be the same object as "
                            "position.domain")

        if constants is None:
            # Build a fresh list per instance instead of sharing one
            # mutable default between all instances.
            constants = []
        self._h = h
        self._constants = constants

        if constants_samples is None:
            constants_samples = constants
        self._constants_samples = constants_samples

        if _samples is None:
            met = h(Linearization.make_partial_var(
                position, constants_samples, True)).metric
            _samples = tuple(met.draw_sample(from_inverse=True)
                             for _ in range(nsamp))
            if gen_mirrored_samples:
                _samples += tuple(-s for s in _samples)
        if len(_samples) == 0:
            # Without this check the averaging below would fail with an
            # opaque "None / 0" TypeError.
            raise ValueError("KL_Energy requires at least one sample")
        self._samples = _samples

        self._lin = Linearization.make_partial_var(position, constants)
        v, g = None, None
        for s in self._samples:
            tmp = self._h(self._lin+s)
            if v is None:
                v = tmp.val.local_data[()]
                g = tmp.gradient
            else:
                v += tmp.val.local_data[()]
                g = g + tmp.gradient
        self._val = v / len(self._samples)
        self._grad = g * (1./len(self._samples))
        # The metric is expensive; it is built lazily in _get_metric.
        self._metric = None

    def at(self, position):
        """Return a KL_Energy at ``position`` that reuses these samples."""
        return KL_Energy(position, self._h, 0,
                         self._constants, self._constants_samples,
                         _samples=self._samples)

    @property
    def value(self):
        """Mean of ``h(position + s).val`` over all samples."""
        return self._val

    @property
    def gradient(self):
        """Mean of ``h(position + s).gradient`` over all samples."""
        return self._grad

    def _get_metric(self):
        """Compute and cache the sample-averaged metric on first use."""
        if self._metric is None:
            lin = self._lin.with_want_metric()
            mymap = map(lambda v: self._h(lin+v).metric, self._samples)
            self._metric = utilities.my_sum(mymap)
            self._metric = self._metric.scale(1./len(self._samples))

    def apply_metric(self, x):
        """Apply the sample-averaged metric to ``x``."""
        self._get_metric()
        return self._metric(x)

    @property
    def metric(self):
        """Sample-averaged metric, computed lazily and cached."""
        self._get_metric()
        return self._metric

    @property
    def samples(self):
        """The samples used for averaging."""
        return self._samples

    def __repr__(self):
        # The operator is stored as self._h; the former reference to the
        # nonexistent attribute self._ham made repr() always raise.
        return 'KL ({} samples):\n'.format(len(
            self._samples)) + utilities.indent(self._h.__repr__())