extra.py 14.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Martin Reinecke's avatar
Martin Reinecke committed
14
# Copyright(C) 2013-2020 Max-Planck-Society
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

Philipp Arras's avatar
Philipp Arras committed
18
19
from itertools import combinations

20
import numpy as np
21
from numpy.testing import assert_
Philipp Arras's avatar
Philipp Arras committed
22

23
from . import random
24
from .domain_tuple import DomainTuple
Martin Reinecke's avatar
fix    
Martin Reinecke committed
25
26
from .field import Field
from .linearization import Linearization
27
from .multi_domain import MultiDomain
28
from .multi_field import MultiField
29
from .operators.energy_operators import EnergyOperator
30
from .operators.linear_operator import LinearOperator
Philipp Arras's avatar
Philipp Arras committed
31
from .operators.operator import Operator
32
from .sugar import from_random
33

Philipp Arras's avatar
Philipp Arras committed
34
__all__ = [
    "check_linear_operator",
    "check_operator",
    "assert_allclose",
]
36

Philipp Arras's avatar
Philipp Arras committed
37

Philipp Arras's avatar
Philipp Arras committed
38
def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
                          atol=1e-12, rtol=1e-12, only_r_linear=False):
    """
    Checks an operator for algebraic consistency of its capabilities.

    Checks whether times(), adjoint_times(), inverse_times() and
    adjoint_inverse_times() (if in capability list) are implemented
    consistently. Additionally, it checks whether the operator is linear.

    Parameters
    ----------
    op : LinearOperator
        Operator which shall be checked.
    domain_dtype : dtype
        The data type of the random vectors in the operator's domain. Default
        is `np.float64`.
    target_dtype : dtype
        The data type of the random vectors in the operator's target. Default
        is `np.float64`.
    atol : float
        Absolute tolerance for the check. It is combined with `rtol` as in
        `numpy.testing.assert_allclose`: a comparison passes where
        ``|actual - desired| <= atol + rtol * |desired|``.
        Default: 1e-12.
    rtol : float
        Relative tolerance for the check (see `atol` for how both are
        combined). Default: 1e-12.
    only_r_linear: bool
        Set to True if the operator is only R-linear, not C-linear.
        This relaxes the adjointness test accordingly: only the real parts
        of the two scalar products are compared. Default: False.

    Raises
    ------
    TypeError
        If `op` is not a `LinearOperator`.
    AssertionError
        If any of the consistency checks fails.
    """
    if not isinstance(op, LinearOperator):
        raise TypeError('This test tests only linear operators.')
    # (operator variant, dtype for its domain, dtype for its target).
    # The dtype pairing mirrors which space each variant takes its input from.
    variants = [(op, domain_dtype, target_dtype),
                (op.adjoint, target_dtype, domain_dtype),
                (op.inverse, target_dtype, domain_dtype),
                (op.adjoint.inverse, domain_dtype, target_dtype)]
    # Same order as before: all domain checks, then all linearity checks,
    # then the adjoint/inverse consistency checks.
    for variant, dom_dtype, _ in variants:
        _domain_check_linear(variant, dom_dtype)
    for variant, dom_dtype, _ in variants:
        _check_linearity(variant, dom_dtype, atol, rtol)
    for variant, dom_dtype, tgt_dtype in variants:
        _full_implementation(variant, dom_dtype, tgt_dtype, atol, rtol,
                             only_r_linear)


Philipp Arras's avatar
Philipp Arras committed
89
def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
                   only_r_differentiable=True, metric_sampling=True):
    """
    Performs various checks of the implementation of linear and nonlinear
    operators.

    Checks domain/target consistency and counts operator applications to
    detect redundant evaluations. Compares `op(loc)` with the value of the
    linearization at `loc`, and compares the Jacobian with finite
    differences. For operators defined on a MultiDomain, also checks the
    behavior when some of the input keys are held constant.

    Parameters
    ----------
    op : Operator
        Operator which shall be checked.
    loc : Field or MultiField
        A Field or MultiField instance which has the same domain
        as op. The location at which the gradient is checked.
    tol : float
        Tolerance for the check. The finite-difference comparison uses
        ``sqrt(tol)`` as its tolerance. Default: 1e-12.
    ntries : int
        Number of successive locations (starting at `loc`) at which the
        Jacobian is compared with finite differences. Default: 100.
    perf_check : Boolean
        The performance check always runs and logs problems. If True, a
        detected performance problem additionally raises a RuntimeError.
        May be disabled for very unimportant operators. Default: True.
    only_r_differentiable : Boolean
        Jacobians of C-differentiable operators need to be C-linear.
        Set to True if `op` is only R-differentiable; this relaxes the
        adjointness test of the Jacobian accordingly. Default: True.
    metric_sampling: Boolean
        If op is an EnergyOperator, metric_sampling determines whether the
        test shall try to sample from the metric or not. Default: True.

    Raises
    ------
    TypeError
        If `op` is not an `Operator`.
    """
    if not isinstance(op, Operator):
        # This checker accepts general (possibly nonlinear) operators.
        raise TypeError('This test tests only (nonlinear) operators.')
    _domain_check_nonlinear(op, loc)
    _performance_check(op, loc, bool(perf_check))
    _linearization_value_consistency(op, loc)
    _jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries,
                               only_r_differentiable)
    _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                               metric_sampling)
Philipp Arras's avatar
Philipp Arras committed
125
126


Philipp Arras's avatar
Philipp Arras committed
127
def assert_allclose(f1, f2, atol, rtol):
    """Assert that two Fields (or MultiFields) agree within tolerances.

    A Field is compared through `numpy.testing.assert_allclose` on its
    values. For anything else, every key of `f1` is looked up in `f2` and
    the corresponding entries are compared recursively.
    """
    if not isinstance(f1, Field):
        for name, part in f1.items():
            assert_allclose(part, f2[name], atol=atol, rtol=rtol)
        return None
    return np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol)
Martin Reinecke's avatar
Martin Reinecke committed
132
133


Philipp Arras's avatar
Philipp Arras committed
134
135
136
137
138
139
140
def assert_equal(f1, f2):
    """Assert that two Fields (or MultiFields) are exactly equal.

    A Field is compared through `numpy.testing.assert_equal` on its values.
    For anything else, every key of `f1` is looked up in `f2` and the
    corresponding entries are compared recursively.
    """
    if not isinstance(f1, Field):
        for name, part in f1.items():
            assert_equal(part, f2[name])
        return None
    return np.testing.assert_equal(f1.val, f2.val)


Philipp Arras's avatar
Philipp Arras committed
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def _nozero(fld):
    """Assert that no entry of a Field (or of any MultiField entry) is zero."""
    if not isinstance(fld, Field):
        for part in fld.values():
            _nozero(part)
        return None
    return np.testing.assert_((fld != 0).s_all())


def _allzero(fld):
    """Assert that every entry of a Field (or of all MultiField entries) is zero."""
    if not isinstance(fld, Field):
        for part in fld.values():
            _allzero(part)
        return None
    return np.testing.assert_((fld == 0.).s_all())


155
156
def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
                            only_r_linear):
Martin Reinecke's avatar
Martin Reinecke committed
157
158
159
    needed_cap = op.TIMES | op.ADJOINT_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
160
161
    f1 = from_random(op.domain, "normal", dtype=domain_dtype)
    f2 = from_random(op.target, "normal", dtype=target_dtype)
Martin Reinecke's avatar
Martin Reinecke committed
162
163
    res1 = f1.s_vdot(op.adjoint_times(f2))
    res2 = op.times(f1).s_vdot(f2)
164
165
    if only_r_linear:
        res1, res2 = res1.real, res2.real
Martin Reinecke's avatar
Martin Reinecke committed
166
167
168
169
170
171
172
    np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)


def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
    needed_cap = op.TIMES | op.INVERSE_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
173
    foo = from_random(op.target, "normal", dtype=target_dtype)
Martin Reinecke's avatar
Martin Reinecke committed
174
    res = op(op.inverse_times(foo))
Philipp Arras's avatar
Philipp Arras committed
175
    assert_allclose(res, foo, atol=atol, rtol=rtol)
Martin Reinecke's avatar
Martin Reinecke committed
176

177
    foo = from_random(op.domain, "normal", dtype=domain_dtype)
Martin Reinecke's avatar
Martin Reinecke committed
178
    res = op.inverse_times(op(foo))
Philipp Arras's avatar
Philipp Arras committed
179
    assert_allclose(res, foo, atol=atol, rtol=rtol)
Martin Reinecke's avatar
Martin Reinecke committed
180
181


182
183
184
185
def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
                         only_r_linear):
    """Run the adjoint check first, then the inverse check, on `op`."""
    _adjoint_implementation(op, domain_dtype, target_dtype,
                            atol=atol, rtol=rtol,
                            only_r_linear=only_r_linear)
    _inverse_implementation(op, domain_dtype, target_dtype,
                            atol=atol, rtol=rtol)


189
def _check_linearity(op, domain_dtype, atol, rtol):
Martin Reinecke's avatar
Martin Reinecke committed
190
191
192
    needed_cap = op.TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
193
194
    fld1 = from_random(op.domain, "normal", dtype=domain_dtype)
    fld2 = from_random(op.domain, "normal", dtype=domain_dtype)
Martin Reinecke's avatar
Martin Reinecke committed
195
    alpha = 0.42
196
197
    val1 = op(alpha*fld1+fld2)
    val2 = alpha*op(fld1)+op(fld2)
Philipp Arras's avatar
Philipp Arras committed
198
    assert_allclose(val1, val2, atol=atol, rtol=rtol)
199
200


Philipp Arras's avatar
Philipp Arras committed
201
202
def _domain_check_linear(op, domain_dtype=None, inp=None):
    """Check that applying `op` maps an input on op.domain onto op.target.

    If `domain_dtype` is given, a random normal input of that dtype is drawn
    (and any `inp` is ignored). Otherwise the caller-supplied `inp` is used.
    Does nothing beyond the basic domain type check unless `op` supports
    TIMES.

    Raises
    ------
    ValueError
        If neither `domain_dtype` nor `inp` is supplied.
    """
    _domain_check(op)
    if (op.capability & op.TIMES) != op.TIMES:
        return
    if domain_dtype is None:
        if inp is None:
            raise ValueError('Need to specify either dtype or inp')
    else:
        inp = from_random(op.domain, "normal", dtype=domain_dtype)
    assert_(inp.domain is op.domain)
    assert_(op(inp).domain is op.target)


Philipp Arras's avatar
Philipp Arras committed
214
215
216
def _domain_check_nonlinear(op, loc):
    """Check the domain/target bookkeeping of `op` and of its linearization.

    Checks that `loc` is a Field/MultiField on op.domain. Then, for
    want_metric False and True, applies `op` to a Linearization at `loc` and
    checks:
    - the domains and targets of the input and resulting linearization;
    - the Jacobian and its adjoint (via `_domain_check_linear`);
    - the metric, if one is returned.
    """
    _domain_check(op)
    assert_(isinstance(loc, (Field, MultiField)))
    assert_(loc.domain is op.domain)
    for wm in [False, True]:
        lin = Linearization.make_var(loc, wm)
        reslin = op(lin)
        # Input linearization: both domain and target are op.domain.
        assert_(lin.domain is op.domain)
        assert_(lin.target is op.domain)
        assert_(lin.val.domain is lin.domain)
        # Output linearization: maps op.domain to op.target.
        assert_(reslin.domain is op.domain)
        assert_(reslin.target is op.target)
        assert_(reslin.val.domain is reslin.target)
        assert_(reslin.jac.domain is reslin.domain)
        assert_(reslin.jac.target is reslin.target)
        assert_(lin.want_metric == reslin.want_metric)
        _domain_check_linear(reslin.jac, inp=loc)
        _domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc))
        if reslin.metric is not None:
            assert_(reslin.metric.domain is reslin.metric.target)
            assert_(reslin.metric.domain is op.domain)
236
237


238
239
240
def _domain_check(op):
    """Check that op.domain and op.target are DomainTuple or MultiDomain.

    Raises
    ------
    TypeError
        If either of them is of a different type.
    """
    for dd in [op.domain, op.target]:
        if not isinstance(dd, (DomainTuple, MultiDomain)):
            # One message string: passing two positional arguments to
            # TypeError would make str(exc) render as a tuple.
            raise TypeError(
                'The domain and the target of an operator need to '
                'be instances of either DomainTuple or MultiDomain.')
244
245


246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def _performance_check(op, pos, raise_on_fail):
    """Detect redundant evaluations of the input of `op`.

    `op` is composed with an identity operator that counts its own
    applications. The expected cumulative counts are compared with the
    observed ones, for want_metric False and True:
    - 1 after a plain call;
    - 2 after applying to a Linearization;
    - 3 after applying the Jacobian;
    - 4 after applying its adjoint;
    - 6 after applying the metric, if present.

    On a mismatch an error is logged, and a RuntimeError is raised if
    `raise_on_fail` is true.
    """
    class _IdentityCounter(LinearOperator):
        # Identity operator on `domain` that records how often it is applied.
        def __init__(self, domain):
            from .sugar import makeDomain
            self._domain = self._target = makeDomain(domain)
            self._capability = self.TIMES | self.ADJOINT_TIMES
            self._count = 0

        def apply(self, x, mode):
            self._count += 1
            return x

        @property
        def count(self):
            return self._count

    for wm in (False, True):
        counter = _IdentityCounter(op.domain)
        probed = op @ counter
        mismatches = []
        probed(pos)
        mismatches.append(counter.count != 1)
        lin = probed(2*Linearization.make_var(pos, wm))
        mismatches.append(counter.count != 2)
        lin.jac(pos)
        mismatches.append(counter.count != 3)
        lin.jac.adjoint(lin.val)
        mismatches.append(counter.count != 4)
        if lin.metric is not None:
            lin.metric(pos)
            mismatches.append(counter.count != 6)
        if not any(mismatches):
            continue
        message = 'The operator has a performance problem (want_metric={}).'.format(wm)
        from .logger import logger
        logger.error(message)
        logger.info(mismatches)
        if raise_on_fail:
            raise RuntimeError(message)
282
283


Martin Reinecke's avatar
Martin Reinecke committed
284
def _get_acceptable_location(op, loc, lin):
Martin Reinecke's avatar
Martin Reinecke committed
285
    if not np.isfinite(lin.val.s_sum()):
Martin Reinecke's avatar
Martin Reinecke committed
286
        raise ValueError('Initial value must be finite')
287
    dir = from_random(loc.domain, dtype=loc.dtype)
Martin Reinecke's avatar
Martin Reinecke committed
288
289
    dirder = lin.jac(dir)
    if dirder.norm() == 0:
Martin Reinecke's avatar
Martin Reinecke committed
290
        dir = dir * (lin.val.norm()*1e-5)
Martin Reinecke's avatar
Martin Reinecke committed
291
    else:
Martin Reinecke's avatar
Martin Reinecke committed
292
        dir = dir * (lin.val.norm()*1e-5/dirder.norm())
Martin Reinecke's avatar
Martin Reinecke committed
293
294
295
296
    # Find a step length that leads to a "reasonable" location
    for i in range(50):
        try:
            loc2 = loc+dir
297
            lin2 = op(Linearization.make_var(loc2, lin.want_metric))
Martin Reinecke's avatar
Martin Reinecke committed
298
            if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20:
Martin Reinecke's avatar
Martin Reinecke committed
299
300
301
302
303
304
305
306
                break
        except FloatingPointError:
            pass
        dir = dir*0.5
    else:
        raise ValueError("could not find a reasonable initial step")
    return loc2, lin2

Martin Reinecke's avatar
Martin Reinecke committed
307

308
309
310
311
312
313
314
315
def _linearization_value_consistency(op, loc):
    """Check that op(loc) equals the value of op applied to a Linearization.

    Both want_metric=False and want_metric=True are checked, with zero
    absolute and 1e-7 relative tolerance.
    """
    # op(loc) does not depend on want_metric; evaluate it only once.
    fld0 = op(loc)
    for wm in [False, True]:
        fld1 = op(Linearization.make_var(loc, wm)).val
        assert_allclose(fld0, fld1, 0, 1e-7)


Philipp Arras's avatar
Philipp Arras committed
316
def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                               metric_sampling):
    """Check `op` when some keys of its MultiDomain input are held constant.

    For each tested set of constant keys, checks that the operator returned
    by `op.simplify_for_constant_input`:
    - gives the same value as `op`;
    - has a Jacobian adjoint matching `op`'s on the variable keys;
    - makes `op`'s partial-Linearization Jacobian adjoint vanish on the
      constant keys.
    If `op` is an EnergyOperator and `metric_sampling` is true, a sample is
    drawn from the metric.

    With at most 4 keys, every nonempty proper subset of keys is tested as
    the constant set; with more keys, each single key is tested on its own.
    Does nothing for operators defined on a DomainTuple.

    NOTE: `tol`, `ntries` and `only_r_differentiable` are accepted for
    interface compatibility but currently unused (a finite-difference check
    of the simplified operator used to live here and is disabled).
    """
    if isinstance(op.domain, DomainTuple):
        return
    keys = op.domain.keys()
    if len(keys) > 4:
        from .logger import logger
        logger.warning('Operator domain has more than 4 keys.')
        logger.warning('Check derivatives only with one constant key at a time.')
        combis = [[kk] for kk in keys]
    else:
        combis = [cc for ll in range(1, len(keys))
                  for cc in combinations(keys, ll)]

    # The full operator's value does not depend on the chosen constant keys.
    val0 = op(loc)
    for cstkeys in combis:
        varkeys = set(keys) - set(cstkeys)
        cstloc = loc.extract_by_keys(cstkeys)
        varloc = loc.extract_by_keys(varkeys)

        _, op0 = op.simplify_for_constant_input(cstloc)
        # assert_ (unlike a bare assert) is not stripped under `python -O`.
        assert_(op0.domain is varloc.domain)
        val1 = op0(varloc)
        assert_equal(val0, val1)

        lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
        lin0 = Linearization.make_var(varloc, want_metric=True)
        oplin0 = op0(lin0)
        oplin = op(lin)

        assert_(oplin.jac.target is oplin0.jac.target)
        rndinp = from_random(oplin.jac.target)
        adj = oplin.jac.adjoint(rndinp)
        # Variable part must agree with the simplified operator's adjoint ...
        assert_allclose(adj.extract(varloc.domain),
                        oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
        # ... and the constant part must be exactly zero.
        foo = adj.extract(cstloc.domain)
        assert_equal(foo, 0*foo)

        if isinstance(op, EnergyOperator) and metric_sampling:
            oplin.metric.draw_sample()
Philipp Arras's avatar
Philipp Arras committed
358
359
360


def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
    """Compare the Jacobian of `op` against finite differences.

    Each of the `ntries` rounds works as follows:
    - Find an acceptable nearby location `loc2` (`_get_acceptable_location`).
    - Compare op(loc + step) - op(loc) with the Jacobian at the midpoint
      applied to the step, halving the step up to 50 times until every entry
      agrees within ``tol * |J step| / sqrt(size)``.
    - Check the Jacobian at the accepted midpoint with
      `check_linear_operator`; `only_r_differentiable` is passed on as
      `only_r_linear`, and ``tol**2`` is used as atol and rtol.
    - Start the next round at `loc2`.

    Raises
    ------
    ValueError
        If no step size within 50 halvings satisfies the tolerance. The
        message includes the history of residual norms.
    """
    for _ in range(ntries):
        lin = op(Linearization.make_var(loc))
        loc2, lin2 = _get_acceptable_location(op, loc, lin)
        step = loc2-loc
        locnext = loc2
        hist = []
        for _halving in range(50):
            locmid = loc + 0.5*step
            linmid = op(Linearization.make_var(locmid))
            # Midpoint rule: J(loc + step/2) @ step approximates
            # op(loc + step) - op(loc) to second order.
            dirder = linmid.jac(step)
            numgrad = lin2.val-lin.val
            xtol = tol * dirder.norm() / np.sqrt(dirder.size)
            hist.append((numgrad-dirder).norm())
            if (abs(numgrad-dirder) <= xtol).s_all():
                break
            # Halve the step; the old midpoint becomes the new end point.
            step = step*0.5
            lin2 = linmid
        else:
            raise ValueError(
                "gradient and value seem inconsistent "
                "(residual norms per step halving: {})".format(hist))
        loc = locnext
        check_linear_operator(linmid.jac, domain_dtype=loc.dtype,
                              target_dtype=dirder.dtype,
                              only_r_linear=only_r_differentiable,
                              atol=tol**2, rtol=tol**2)