# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from .operators.scaling_operator import ScalingOperator


class Linearization(object):
    """Let `A` be an operator and `x` a field. `Linearization` stores the value
    of the operator application (i.e. `A(x)`), the local Jacobian
    (i.e. `dA(x)/dx`) and, optionally, the local metric.

    Parameters
    ----------
    val : Field/MultiField
        the value of the operator application
    jac : LinearOperator
        the Jacobian
    metric : LinearOperator or None (default: None)
        the metric
    want_metric : bool (default: False)
        if True, the metric will be computed for other Linearizations derived
        from this one.

    Raises
    ------
    ValueError
        if `val.domain` does not match `jac.target`.
    """

    def __init__(self, val, jac, metric=None, want_metric=False):
        self._val = val
        self._jac = jac
        # The value must live on the space the Jacobian maps into.
        if self._val.domain != self._jac.target:
            raise ValueError("domain mismatch")
        self._want_metric = want_metric
        self._metric = metric

    def new(self, val, jac, metric=None):
        """Create a new Linearization, taking the `want_metric` property from
        this one.

        Parameters
        ----------
        val : Field/MultiField
            the value of the operator application
        jac : LinearOperator
            the Jacobian
        metric : LinearOperator or None (default: None)
            the metric
        """
        return Linearization(val, jac, metric, self._want_metric)

    @property
    def domain(self):
        """DomainTuple/MultiDomain : the Jacobian's domain"""
        return self._jac.domain

    @property
    def target(self):
        """DomainTuple/MultiDomain : the Jacobian's target (i.e. the value's
        domain)"""
        return self._jac.target

    @property
    def val(self):
        """Field/MultiField : the value"""
        return self._val

    @property
    def jac(self):
        """LinearOperator : the Jacobian"""
        return self._jac

    @property
    def gradient(self):
        """Field/MultiField : the gradient

        Notes
        -----
        Only available if target is a scalar
        """
        # For a scalar-valued map the gradient is the adjoint Jacobian
        # applied to the unit scalar.
        return self._jac.adjoint_times(Field.scalar(1.))

    @property
    def want_metric(self):
        """bool : the value of `want_metric`"""
        return self._want_metric

    @property
    def metric(self):
        """LinearOperator : the metric

        Notes
        -----
        Only available if target is a scalar
        """
        return self._metric

    def __getitem__(self, name):
        """Return the Linearization of the component `name` of the value.

        The Jacobian is the extractor of `name` on `self.target` chained with
        this Linearization's Jacobian, so components that were marked constant
        (zero-scaled by `make_partial_var`) keep a zero derivative.
        """
        from .operators.simple_linear_operators import ducktape
        # Extract `name` from the Jacobian's *target* and chain with the
        # Jacobian; extracting from `self.domain` without chaining is only
        # correct when the Jacobian is the identity.
        return self.new(self._val[name],
                        ducktape(None, self.target, name)(self._jac))

    def __neg__(self):
        return self.new(-self._val, -self._jac,
                        None if self._metric is None else -self._metric)

    def conjugate(self):
        """Return the Linearization of the complex conjugate of the value."""
        return self.new(
            self._val.conjugate(), self._jac.conjugate(),
            None if self._metric is None else self._metric.conjugate())

    @property
    def real(self):
        """Linearization : the real part of the value (metric is dropped)"""
        return self.new(self._val.real, self._jac.real)

    def _myadd(self, other, neg):
        """Return `self + other`, or `self - other` if `neg` is True.

        `other` may be a Linearization, a Python/numpy scalar, a Field or a
        MultiField. For any other type `NotImplemented` is returned so that
        Python can try the reflected operation or raise a TypeError.
        """
        if isinstance(other, Linearization):
            met = None
            # The metric survives only if both operands carry one.
            if self._metric is not None and other._metric is not None:
                met = self._metric._myadd(other._metric, neg)
            return self.new(
                self._val.flexible_addsub(other._val, neg),
                self._jac._myadd(other._jac, neg), met)
        if isinstance(other, (int, float, complex, np.number,
                              Field, MultiField)):
            # A constant offset changes the value only; the Jacobian and the
            # metric are unaffected.
            val = self._val - other if neg else self._val + other
            return self.new(val, self._jac, self._metric)
        return NotImplemented

    def __add__(self, other):
        return self._myadd(other, False)

    def __radd__(self, other):
        return self._myadd(other, False)

    def __sub__(self, other):
        return self._myadd(other, True)

    def __rsub__(self, other):
        # other - self == (-self) + other
        return (-self).__add__(other)

    def __truediv__(self, other):
        if isinstance(other, Linearization):
            return self.__mul__(other.one_over())
        return self.__mul__(1./other)

    def __rtruediv__(self, other):
        # other / self == (1/self) * other
        return self.one_over().__mul__(other)

    def __pow__(self, power):
        """Raise the value to a scalar `power` (metric is dropped)."""
        if not np.isscalar(power):
            return NotImplemented
        # d(v**p) = p * v**(p-1) * dv
        return self.new(self._val**power,
                        makeOp(self._val**(power-1)).scale(power)(self._jac))

    def __mul__(self, other):
        """Multiply by a Linearization, a scalar, a Field or a MultiField.

        Raises
        ------
        ValueError
            if the domains of the operands do not match.
        """
        if isinstance(other, Linearization):
            if self.target != other.target:
                raise ValueError("domain mismatch")
            # Product rule: d(a*b) = b*da + a*db
            return self.new(
                self._val*other._val,
                (makeOp(other._val)(self._jac))._myadd(
                    makeOp(self._val)(other._jac), False))
        if np.isscalar(other):
            if other == 1:
                return self
            met = None if self._metric is None else self._metric.scale(other)
            return self.new(self._val*other, self._jac.scale(other), met)
        if isinstance(other, (Field, MultiField)):
            if self.target != other.domain:
                raise ValueError("domain mismatch")
            return self.new(self._val*other, makeOp(other)(self._jac))
        # Unsupported operand: let Python try the reflected operation.
        return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)

    def outer(self, other):
        """Outer product of the value with a Linearization, a scalar, a Field
        or a MultiField."""
        from .operators.outer_product_operator import OuterProduct
        # NOTE(review): `OuterProduct(f, dom)` is applied to fields living on
        # `dom` below, so `OuterProduct(self._jac(self._val), ...)` appears to
        # have domain `other.target`/`other.domain` rather than
        # `self.domain`, and does not map a perturbation `d` to
        # `(self.jac d) (x) other.val` as the product rule requires.
        # Please verify the Jacobians of both branches.
        if isinstance(other, Linearization):
            return self.new(
                OuterProduct(self._val, other.target)(other._val),
                OuterProduct(self._jac(self._val), other.target)._myadd(
                    OuterProduct(self._val, other.target)(other._jac), False))
        if np.isscalar(other):
            return self.__mul__(other)
        if isinstance(other, (Field, MultiField)):
            return self.new(OuterProduct(self._val, other.domain)(other),
                            OuterProduct(self._jac(self._val), other.domain))

    def vdot(self, other):
        """Scalar product of the value with a Field/MultiField or with the
        value of another Linearization; the result value is a scalar Field."""
        from .operators.simple_linear_operators import VdotOperator
        if isinstance(other, (Field, MultiField)):
            return self.new(
                Field.scalar(self._val.vdot(other)),
                VdotOperator(other)(self._jac))
        # Both operands depend on the input: sum of both Jacobian terms.
        return self.new(
            Field.scalar(self._val.vdot(other._val)),
            VdotOperator(self._val)(other._jac) +
            VdotOperator(other._val)(self._jac))

    def sum(self, spaces=None):
        """Sum the value over `spaces` (over everything, yielding a scalar
        Field, if `spaces` is None)."""
        from .operators.contraction_operator import ContractionOperator
        if spaces is None:
            return self.new(
                Field.scalar(self._val.sum()),
                ContractionOperator(self._jac.target, None)(self._jac))
        else:
            return self.new(
                self._val.sum(spaces),
                ContractionOperator(self._jac.target, spaces)(self._jac))

    def integrate(self, spaces=None):
        """Integrate the value over `spaces` (over everything, yielding a
        scalar Field, if `spaces` is None)."""
        from .operators.contraction_operator import ContractionOperator
        if spaces is None:
            return self.new(
                Field.scalar(self._val.integrate()),
                ContractionOperator(self._jac.target, None, 1)(self._jac))
        else:
            return self.new(
                self._val.integrate(spaces),
                ContractionOperator(self._jac.target, spaces, 1)(self._jac))

    def exp(self):
        tmp = self._val.exp()
        return self.new(tmp, makeOp(tmp)(self._jac))

    def clip(self, min=None, max=None):
        """Clip the value to [`min`, `max`]; the derivative is set to zero
        where the clipped value equals a given bound."""
        # Nothing to clip: return early, before calling the underlying clip
        # with two None bounds.
        if (min is None) and (max is None):
            return self
        tmp = self._val.clip(min, max)
        if max is None:
            tmp2 = makeOp(1. - (tmp == min))
        elif min is None:
            tmp2 = makeOp(1. - (tmp == max))
        else:
            tmp2 = makeOp(1. - (tmp == min) - (tmp == max))
        return self.new(tmp, tmp2(self._jac))

    def sin(self):
        tmp = self._val.sin()
        tmp2 = self._val.cos()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def cos(self):
        tmp = self._val.cos()
        tmp2 = - self._val.sin()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def tan(self):
        tmp = self._val.tan()
        tmp2 = 1./(self._val.cos()**2)
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def sinc(self):
        """Apply `sinc` to the value.

        Assumes `Field.sinc` follows numpy's normalized convention
        sin(pi*x)/(pi*x) -- TODO confirm; its derivative is
        (cos(pi*x) - sinc(x))/x.
        """
        tmp = self._val.sinc()
        tmp2 = ((np.pi*self._val).cos()-tmp)/self._val
        # NOTE(review): at x == 0 this evaluates 0/0 (the true derivative is
        # 0); those entries are not special-cased here.
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def log(self):
        tmp = self._val.log()
        return self.new(tmp, makeOp(1./self._val)(self._jac))

    def sinh(self):
        tmp = self._val.sinh()
        tmp2 = self._val.cosh()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def cosh(self):
        tmp = self._val.cosh()
        tmp2 = self._val.sinh()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def tanh(self):
        tmp = self._val.tanh()
        return self.new(tmp, makeOp(1.-tmp**2)(self._jac))

    def sigmoid(self):
        """Apply 0.5*(1 + tanh(x)) to the value."""
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))

    def absolute(self):
        """Apply `absolute`; the derivative used is `sign` of the value."""
        tmp = self._val.absolute()
        tmp2 = self._val.sign()
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def one_over(self):
        """Return the Linearization of 1/value."""
        tmp = 1./self._val
        # d(1/v) = -1/v**2 dv
        tmp2 = - tmp/self._val
        return self.new(tmp, makeOp(tmp2)(self._jac))

    def add_metric(self, metric):
        """Return a copy of this Linearization carrying `metric`."""
        return self.new(self._val, self._jac, metric)

    def with_want_metric(self):
        """Return a copy of this Linearization with `want_metric` set."""
        return Linearization(self._val, self._jac, self._metric, True)

    @staticmethod
    def make_var(field, want_metric=False):
        """Linearization of `field` w.r.t. itself (identity Jacobian)."""
        return Linearization(field, ScalingOperator(1., field.domain),
                             want_metric=want_metric)

    @staticmethod
    def make_const(field, want_metric=False):
        """Linearization of a constant `field` (null Jacobian on
        `field.domain`)."""
        from .operators.simple_linear_operators import NullOperator
        return Linearization(field, NullOperator(field.domain, field.domain),
                             want_metric=want_metric)

    @staticmethod
    def make_const_empty_input(field, want_metric=False):
        """Linearization of a constant `field` whose Jacobian is a null
        operator with an empty MultiDomain as its domain."""
        from .operators.simple_linear_operators import NullOperator
        from .multi_domain import MultiDomain
        return Linearization(
            field, NullOperator(MultiDomain.make({}), field.domain),
            want_metric=want_metric)

    @staticmethod
    def make_partial_var(field, constants, want_metric=False):
        """Linearization of a MultiField `field` w.r.t. itself, where the
        entries whose keys are in `constants` get a zero Jacobian block."""
        from .operators.block_diagonal_operator import BlockDiagonalOperator
        if len(constants) == 0:
            return Linearization.make_var(field, want_metric)
        else:
            ops = [ScalingOperator(0. if key in constants else 1., dom)
                   for key, dom in field.domain.items()]
            bdop = BlockDiagonalOperator(field.domain, tuple(ops))
            return Linearization(field, bdop, want_metric=want_metric)