energy_adapter.py 8.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Philipp Arras's avatar
Philipp Arras committed
14
# Copyright(C) 2013-2021 Max-Planck-Society
15
16
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

18
19
import numpy as np

Philipp Frank's avatar
Philipp Frank committed
20
from .. import random
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
22
from ..minimization.energy import Energy
Philipp Frank's avatar
Philipp Frank committed
23
24
25
from ..utilities import myassert, allreduce_sum
from ..multi_domain import MultiDomain
from ..sugar import from_random
Philipp Frank's avatar
Philipp Frank committed
26
27
from ..domain_tuple import DomainTuple

Martin Reinecke's avatar
Martin Reinecke committed
28
29

class EnergyAdapter(Energy):
    """Helper class which provides the traditional Nifty Energy interface to
    Nifty operators with a scalar target domain.

    Parameters
    ----------
    position : Field or MultiField
        The position where the minimization process is started.
    op : EnergyOperator
        The expression computing the energy from the input data.
    constants : list of strings
        The component names of the operator's input domain which are assumed
        to be constant during the minimization process.
        If the operator's input domain is not a MultiField, this must be empty.
        Default: empty.
    want_metric : bool
        If True, the class will provide a `metric` property. This should only
        be enabled if it is required, because it will most likely consume
        additional resources. Default: False.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on
        these occasions but rather the minimizer is told that the position it
        has tried is not sensible.
    """

    def __init__(self, position, op, constants=(), want_metric=False,
                 nanisinf=False):
        # An immutable default avoids the shared-mutable-default pitfall;
        # any sized collection of keys is accepted.
        if len(constants) > 0:
            # Freeze the constant components at their values in `position`
            # and keep only the remaining (variable) components as position.
            cstpos = position.extract_by_keys(constants)
            _, op = op.simplify_for_constant_input(cstpos)
            varkeys = set(op.domain.keys()) - set(constants)
            position = position.extract_by_keys(varkeys)
        super(EnergyAdapter, self).__init__(position)
        self._op = op
        self._want_metric = want_metric
        # Evaluate value, gradient and (if requested) metric in one pass.
        lin = Linearization.make_var(position, want_metric)
        tmp = self._op(lin)
        self._val = tmp.val.val[()]
        self._grad = tmp.gradient
        # None unless a metric was requested/provided by `op`.
        self._metric = tmp._metric
        self._nanisinf = bool(nanisinf)
        if self._nanisinf and np.isnan(self._val):
            self._val = np.inf

    def at(self, position):
        """Return a new EnergyAdapter evaluated at `position`.

        The constant components were already absorbed into the stored
        operator by the constructor, so `position` is expected to cover only
        the variable components.
        """
        return EnergyAdapter(position, self._op, want_metric=self._want_metric,
                             nanisinf=self._nanisinf)

    @property
    def value(self):
        """The energy value at the current position (inf instead of nan if
        `nanisinf` was set)."""
        return self._val

    @property
    def gradient(self):
        """The gradient of the energy at the current position."""
        return self._grad

    @property
    def metric(self):
        """The metric at the current position, or None if it was not
        requested via `want_metric`."""
        return self._metric

    def apply_metric(self, x):
        """Apply the metric at the current position to `x`.

        Raises
        ------
        TypeError
            If no metric is available (e.g. `want_metric` was False).
        """
        if self._metric is None:
            raise TypeError("EnergyAdapter has no metric; construct it with "
                            "want_metric=True to use apply_metric")
        return self._metric(x)


class StochasticEnergyAdapter(Energy):
    """Provide the energy interface for an energy operator where parts of the
    input are averaged instead of optimized.

    Specifically, a set of standard normal distributed samples is drawn for
    the input corresponding to `keys`, and each sample is inserted into `op`
    as constant input for those keys. The resulting operators are then
    averaged. What remains is a stochastic average of the energy whose domain
    is the non-sampled subdomain; its degrees of freedom are the optimization
    parameters.

    Notes
    -----
    `StochasticEnergyAdapter` should never be created using the constructor,
    but rather via the factory function :attr:`make`.
    """

    def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
                 sampling_context, _callingfrommake=False):
        if not _callingfrommake:
            raise NotImplementedError(
                "StochasticEnergyAdapter must be created via "
                "StochasticEnergyAdapter.make")
        super(StochasticEnergyAdapter, self).__init__(position)
        for lop in local_ops:
            myassert(position.domain == lop.domain)
        self._comm = comm
        self._local_ops = local_ops
        # Total number of samples (already doubled when mirrored samples are
        # used); this is the normalization of the averages below.
        self._n_samples = n_samples
        self._nanisinf = nanisinf
        lin = Linearization.make_var(position)
        v, g = [], []
        for lop in self._local_ops:
            tmp = lop(lin)
            v.append(tmp.val.val)
            g.append(tmp.gradient)
        # Local contributions are combined over `comm` (presumably summed
        # across all tasks) and divided by the total sample count.
        self._val = allreduce_sum(v, self._comm)[()]/self._n_samples
        if np.isnan(self._val) and self._nanisinf:
            self._val = np.inf
        self._grad = allreduce_sum(g, self._comm)/self._n_samples

        self._op = op
        self._keys = keys
        # (op, position, seed, comm, n_samples, mirror_samples, sample_domain)
        # as built by `make`; lets the samples be regenerated exactly.
        self._context = sampling_context

    @property
    def value(self):
        """The sample-averaged energy (inf instead of nan if `nanisinf`)."""
        return self._val

    @property
    def gradient(self):
        """The sample-averaged gradient at the current position."""
        return self._grad

    def at(self, position):
        """Return a new adapter at `position` that reuses the same samples."""
        return StochasticEnergyAdapter(position, self._op, self._keys,
                                       self._local_ops, self._n_samples,
                                       self._comm, self._nanisinf,
                                       self._context, _callingfrommake=True)

    def apply_metric(self, x):
        """Apply the sample-averaged metric at the current position to `x`."""
        lin = Linearization.make_var(self.position, want_metric=True)
        res = []
        for op in self._local_ops:
            res.append(op(lin).metric(x))
        return allreduce_sum(res, self._comm)/self._n_samples

    @property
    def metric(self):
        """Self-adjoint operator wrapping :meth:`apply_metric`."""
        from .kl_energies import _SelfAdjointOperatorWrapper
        return _SelfAdjointOperatorWrapper(self.position.domain,
                                           self.apply_metric)

    def resample_at(self, position):
        """Return a new adapter at `position` with freshly drawn samples.

        The sampling settings of this adapter are reused: the requested
        number of samples, mirroring, communicator and nan handling.
        """
        # `self._n_samples` is already doubled when mirroring, so the
        # originally requested settings are taken from the sampling context.
        _, _, _, comm, n_samples, mirror_samples, _ = self._context
        return StochasticEnergyAdapter.make(position, self._op, self._keys,
                                            n_samples, mirror_samples,
                                            comm=comm,
                                            nanisinf=self._nanisinf)

    @classmethod
    def make(cls, position, op, sampling_keys, n_samples, mirror_samples,
             comm=None, nanisinf=False):
        """Factory function for StochasticEnergyAdapter.

        Parameters
        ----------
        position : MultiField
            Values of the optimization parameters
        op : Operator
            The objective function of the optimization problem. Must have a
            scalar target. The domain must be a `MultiDomain` with its keys
            being the union of `sampling_keys` and `position.domain.keys()`.
        sampling_keys : iterable of String
            The keys of the subdomain over which the stochastic average of `op`
            should be performed.
        n_samples : int
            Number of samples used for the stochastic estimate.
        mirror_samples : boolean
            Whether the negative of the drawn samples are also used, as they are
            equally legitimate samples. If true, the number of used samples
            doubles.
        comm : MPI communicator or None
            If not None, samples will be distributed as evenly as possible
            across this communicator. If `mirror_samples` is set, then a sample
            and its mirror image will always reside on the same task.
        nanisinf : bool
            If true, nan energies, which can occur due to overflows in the
            forward model, are interpreted as inf which can be interpreted by
            optimizers.

        Raises
        ------
        TypeError
            If `n_samples` is not an int.
        ValueError
            If a sampling key is part of `position` or not part of `op`'s
            domain.
        """
        myassert(op.target == DomainTuple.scalar_domain())
        if not isinstance(n_samples, int):
            raise TypeError("n_samples must be an int")
        # Materialize the keys: they are iterated here and stored for
        # `resample_at`, which would break for a one-shot iterator.
        sampling_keys = tuple(sampling_keys)
        samdom = {}
        for k in sampling_keys:
            if (k in position.domain.keys()) or (k not in op.domain.keys()):
                raise ValueError(
                    f"sampling key {k!r} must be in op.domain and must not "
                    "be in position.domain")
            samdom[k] = op.domain[k]
        samdom = MultiDomain.make(samdom)
        # A single seed drawn from the global RNG determines all samples, so
        # `samples()` can regenerate them later.
        seed = int(random.current_rng().integers(0, 1000000))
        context = op, position, seed, comm, n_samples, mirror_samples, samdom
        noise = cls._draw_noise(*context)
        local_ops = [op.simplify_for_constant_input(nn)[1] for nn in noise]
        n_samples = 2*n_samples if mirror_samples else n_samples
        return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
                                       n_samples, comm, nanisinf, context,
                                       _callingfrommake=True)

    @staticmethod
    def _draw_noise(op, position, seed, comm, n_samples, mirror_samples,
                    sample_domain):
        """Draw the standard-normal samples local to this task.

        `op` and `position` are part of the sampling context but unused here.
        Sample `i` is drawn from its own spawned seed sequence, so the result
        does not depend on how samples are split across tasks.
        """
        from .kl_energies import _get_lo_hi
        with random.Context(seed):
            sseq = random.spawn_sseq(n_samples)
        noise = []
        for i in range(*_get_lo_hi(comm, n_samples)):
            with random.Context(sseq[i]):
                rnd = from_random(sample_domain)
                noise.append(rnd)
                if mirror_samples:
                    noise.append(-rnd)
        return noise

    def samples(self):
        """Standard-normal samples that have been inserted into `op`.

        They are regenerated from the stored seed. Only the samples local to
        this task are returned, including mirrored ones if enabled.
        """
        return self._draw_noise(*self._context)