# kl_energy.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .energy import Energy
from ..linearization import Linearization
from .. import utilities


class KL_Energy(Energy):
    """Energy obtained by averaging `h` over a fixed set of samples.

    The value and gradient at `position` are the sample means of
    ``h(lin + s)`` over all stored samples ``s``, where ``lin`` is the
    partial linearization of `position` built with `constants`. The samples
    are drawn once, from the inverse of the metric of `h` evaluated at
    `position`, and are reused by :meth:`at`. Every energy derived from this
    one via :meth:`at` therefore sees exactly the same samples.

    NOTE(review): the class name suggests this is a Monte Carlo estimate of
    a Kullback-Leibler divergence with `h` a Hamiltonian -- confirm.

    Parameters
    ----------
    position
        Point at which the energy is evaluated. Its domain must be the very
        same object as ``h.domain``.
    h
        Operator averaged over the samples. It is called on Linearization
        objects, and its result must expose ``val``, ``gradient`` and
        ``metric``.
    nsamp : int
        Number of samples to draw. Ignored when `_samples` is given.
    constants : list or None, optional
        Passed to ``Linearization.make_partial_var`` when evaluating value,
        gradient and metric (presumably the parts of `position` held fixed;
        confirm). ``None`` (the default) means an empty list.
    constants_samples : list or None, optional
        Same role as `constants`, but for the linearization at which the
        samples are drawn. Defaults to `constants`.
    gen_mirrored_samples : bool, optional
        If True, the negation of every drawn sample is appended as well,
        doubling the number of samples.
    _samples : tuple or None, optional
        Internal: pre-drawn samples to reuse instead of drawing new ones.

    Raises
    ------
    TypeError
        If ``h.domain`` is not the same object as ``position.domain``.
    """

    def __init__(self, position, h, nsamp, constants=None,
                 constants_samples=None, gen_mirrored_samples=False,
                 _samples=None):
        super(KL_Energy, self).__init__(position)
        # Identity (not equality) comparison of the domains.
        if h.domain is not position.domain:
            raise TypeError("domain of h does not match domain of position")

        # Use a fresh list per instance instead of a shared mutable default.
        if constants is None:
            constants = []
        self._h = h
        self._constants = constants
        if constants_samples is None:
            constants_samples = constants
        self._constants_samples = constants_samples

        if _samples is None:
            # Draw samples from the inverse metric of h at `position`.
            met = h(Linearization.make_partial_var(
                position, constants_samples, True)).metric
            _samples = tuple(met.draw_sample(from_inverse=True)
                             for _ in range(nsamp))
            if gen_mirrored_samples:
                # Mirrored samples: append -s for every drawn s.
                _samples += tuple(-s for s in _samples)
        self._samples = _samples

        self._lin = Linearization.make_partial_var(position, constants)
        # Accumulate value and gradient over all samples, then average.
        # NOTE: at least one sample is required; with none, `v` stays None
        # and the division below fails.
        v, g = None, None
        for s in self._samples:
            tmp = self._h(self._lin+s)
            if v is None:
                v = tmp.val.local_data[()]
                g = tmp.gradient
            else:
                v += tmp.val.local_data[()]
                g = g + tmp.gradient
        self._val = v / len(self._samples)
        self._grad = g * (1./len(self._samples))
        # The metric is expensive; it is built lazily by _get_metric.
        self._metric = None

    def at(self, position):
        """Return a KL_Energy at `position` that reuses this one's samples."""
        return KL_Energy(position, self._h, 0,
                         self._constants, self._constants_samples,
                         _samples=self._samples)

    @property
    def value(self):
        """Sample-averaged value of `h`."""
        return self._val

    @property
    def gradient(self):
        """Sample-averaged gradient of `h`."""
        return self._grad

    def _get_metric(self):
        """Build and cache the sample-averaged metric on first use."""
        if self._metric is None:
            lin = self._lin.with_want_metric()
            metrics = (self._h(lin+s).metric for s in self._samples)
            self._metric = utilities.my_sum(metrics).scale(
                1./len(self._samples))

    def apply_metric(self, x):
        """Apply the sample-averaged metric to `x`."""
        self._get_metric()
        return self._metric(x)

    @property
    def metric(self):
        """Sample-averaged metric (built lazily and cached)."""
        self._get_metric()
        return self._metric

    @property
    def samples(self):
        """Tuple of samples over which `h` is averaged."""
        return self._samples

    def __repr__(self):
        # The operator is stored as `_h`; there is no `_ham` attribute.
        return 'KL ({} samples):\n'.format(len(
            self._samples)) + utilities.indent(repr(self._h))