linearization.py 8.12 KB
Newer Older
1

Martin Reinecke's avatar
Martin Reinecke committed
2
3
4
5
6
7
from __future__ import absolute_import, division, print_function

import numpy as np

from .compat import *
from .field import Field
Martin Reinecke's avatar
Martin Reinecke committed
8
from .multi_field import MultiField
Martin Reinecke's avatar
Martin Reinecke committed
9
10
11
12
from .sugar import makeOp


class Linearization(object):
    """A value paired with its Jacobian and an optional metric.

    The arithmetic and elementwise methods of this class compute the value
    of the result and build the corresponding Jacobian via the product and
    chain rules (see e.g. ``__mul__`` and ``exp``).

    Parameters
    ----------
    val : Field or MultiField
        The value. Its ``domain`` must equal ``jac.target``.
    jac : LinearOperator
        The Jacobian belonging to `val`.
    metric : LinearOperator or None
        Optional metric; only meaningful if the target is a scalar.
    want_metric : bool
        Flag carried over unchanged (via ``new``) to every Linearization
        derived from this one.

    Raises
    ------
    ValueError
        If ``val.domain != jac.target``.
    """

    def __init__(self, val, jac, metric=None, want_metric=False):
        self._val = val
        self._jac = jac
        if self._val.domain != self._jac.target:
            raise ValueError("domain mismatch")
        self._want_metric = want_metric
        self._metric = metric

    def new(self, val, jac, metric=None):
        """Return a new Linearization that inherits ``want_metric``."""
        return Linearization(val, jac, metric, self._want_metric)

    @property
    def domain(self):
        """The domain of the Jacobian."""
        return self._jac.domain

    @property
    def target(self):
        """The target of the Jacobian (equal to ``val.domain``)."""
        return self._jac.target

    @property
    def val(self):
        """The value of the linearization."""
        return self._val

    @property
    def jac(self):
        """The Jacobian of the linearization."""
        return self._jac

    @property
    def gradient(self):
        """Only available if target is a scalar"""
        # Applying the adjoint Jacobian to the scalar 1 yields the gradient.
        return self._jac.adjoint_times(Field.scalar(1.))

    @property
    def want_metric(self):
        """Whether a metric has been requested for this linearization."""
        return self._want_metric

    @property
    def metric(self):
        """Only available if target is a scalar"""
        return self._metric

    def __getitem__(self, name):
        """Return the linearization of the component `name` of the value.

        NOTE(review): the returned Jacobian is built from
        ``self.domain[name]`` alone and does not involve ``self._jac``; this
        is presumably only correct when ``self`` is a variable (identity
        Jacobian, e.g. from ``make_var``) -- confirm before relying on it.
        """
        from .operators.simple_linear_operators import FieldAdapter
        return self.new(self._val[name], FieldAdapter(self.domain[name], name))

    def __neg__(self):
        return self.new(-self._val, -self._jac,
                        None if self._metric is None else -self._metric)

    def conjugate(self):
        """Return the complex conjugate (value, Jacobian and metric)."""
        return self.new(
            self._val.conjugate(), self._jac.conjugate(),
            None if self._metric is None else self._metric.conjugate())

    @property
    def real(self):
        """The real part; the metric is not carried over."""
        return self.new(self._val.real, self._jac.real)

    def _myadd(self, other, neg):
        """Return ``self + other``, or ``self - other`` if `neg` is True.

        `other` may be a Linearization, a Python or NumPy number, a Field or
        a MultiField. For any other type ``NotImplemented`` is returned, so
        that Python's binary-operator protocol can try the reflected
        operation or raise ``TypeError`` (instead of silently yielding
        ``None``).
        """
        if isinstance(other, Linearization):
            # The result only has a metric if both operands carry one.
            met = None
            if self._metric is not None and other._metric is not None:
                met = self._metric._myadd(other._metric, neg)
            return self.new(
                self._val.flexible_addsub(other._val, neg),
                self._jac._myadd(other._jac, neg), met)
        if isinstance(other, (int, float, complex, np.number,
                              Field, MultiField)):
            # Adding a constant only shifts the value; the Jacobian and the
            # metric are unchanged.
            if neg:
                return self.new(self._val-other, self._jac, self._metric)
            return self.new(self._val+other, self._jac, self._metric)
        return NotImplemented

    def __add__(self, other):
        return self._myadd(other, False)

    def __radd__(self, other):
        return self._myadd(other, False)

    def __sub__(self, other):
        return self._myadd(other, True)

    def __rsub__(self, other):
        return (-self).__add__(other)

    def __truediv__(self, other):
        if isinstance(other, Linearization):
            return self.__mul__(other.inverse())
        return self.__mul__(1./other)

    def __rtruediv__(self, other):
        return self.inverse().__mul__(other)

    # Python 2 dispatches `/` to __div__/__rdiv__ in modules that do not use
    # `from __future__ import division`; this file imports `.compat`,
    # presumably for Python 2 support, so provide the aliases.
    __div__ = __truediv__
    __rdiv__ = __rtruediv__

    def __pow__(self, power):
        """Return ``self**power`` for scalar `power`; the metric is dropped."""
        if not np.isscalar(power):
            return NotImplemented
        # d(x**p) = p * x**(p-1) dx
        return self.new(self._val**power,
                        makeOp(self._val**(power-1)).scale(power)(self._jac))

    def inverse(self):
        """Return ``1/self``; the metric is dropped."""
        # d(1/x) = -1/x**2 dx
        return self.new(1./self._val, makeOp(-1./(self._val**2))(self._jac))

    def __mul__(self, other):
        """Return ``self * other`` with the product rule for the Jacobian.

        Supports Linearizations, scalars (``np.isscalar``), Fields and
        MultiFields; returns ``NotImplemented`` for any other type.

        Raises
        ------
        ValueError
            If the targets/domains of the operands do not match.
        """
        if isinstance(other, Linearization):
            if self.target != other.target:
                raise ValueError("domain mismatch")
            # d(a*b) = b*da + a*db; the metric is dropped.
            return self.new(
                self._val*other._val,
                (makeOp(other._val)(self._jac))._myadd(
                 makeOp(self._val)(other._jac), False))
        if np.isscalar(other):
            if other == 1:
                return self
            met = None if self._metric is None else self._metric.scale(other)
            return self.new(self._val*other, self._jac.scale(other), met)
        if isinstance(other, (Field, MultiField)):
            if self.target != other.domain:
                raise ValueError("domain mismatch")
            return self.new(self._val*other, makeOp(other)(self._jac))
        return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)

    def outer(self, other):
        """Return the outer product of ``self`` with `other`.

        `other` may be a Linearization, a scalar (delegates to ``__mul__``),
        a Field or a MultiField.

        NOTE(review): in both non-scalar branches the first Jacobian term
        applies ``self._jac`` to ``self._val`` and uses the result as the
        fixed outer-product factor. By the product rule one would expect a
        term mapping ``x -> (self._jac x) outer other_value`` instead --
        confirm against the OuterProduct definition before relying on the
        Jacobian.

        Raises
        ------
        TypeError
            If `other` has an unsupported type.
        """
        from .operators.outer_product_operator import OuterProduct
        if isinstance(other, Linearization):
            return self.new(
                OuterProduct(self._val, other.target)(other._val),
                OuterProduct(self._jac(self._val), other.target)._myadd(
                    OuterProduct(self._val, other.target)(other._jac), False))
        if np.isscalar(other):
            return self.__mul__(other)
        if isinstance(other, (Field, MultiField)):
            return self.new(OuterProduct(self._val, other.domain)(other),
                            OuterProduct(self._jac(self._val), other.domain))
        raise TypeError("unsupported operand type for outer(): {}".format(
            type(other).__name__))

    def vdot(self, other):
        """Return the scalar product of ``self`` with `other`.

        `other` may be a Field, a MultiField or a Linearization.

        Raises
        ------
        TypeError
            If `other` has an unsupported type.
        """
        from .operators.simple_linear_operators import VdotOperator
        if isinstance(other, (Field, MultiField)):
            return self.new(
                Field.scalar(self._val.vdot(other)),
                VdotOperator(other)(self._jac))
        if not isinstance(other, Linearization):
            raise TypeError("unsupported operand type for vdot(): {}".format(
                type(other).__name__))
        return self.new(
            Field.scalar(self._val.vdot(other._val)),
            VdotOperator(self._val)(other._jac) +
            VdotOperator(other._val)(self._jac))

    def sum(self, spaces=None):
        """Sum the value over `spaces` (all spaces if None)."""
        from .operators.contraction_operator import ContractionOperator
        if spaces is None:
            return self.new(
                Field.scalar(self._val.sum()),
                ContractionOperator(self._jac.target, None)(self._jac))
        else:
            return self.new(
                self._val.sum(spaces),
                ContractionOperator(self._jac.target, spaces)(self._jac))

    def integrate(self, spaces=None):
        """Integrate the value over `spaces` (all spaces if None)."""
        from .operators.contraction_operator import ContractionOperator
        if spaces is None:
            return self.new(
                Field.scalar(self._val.integrate()),
                ContractionOperator(self._jac.target, None, 1)(self._jac))
        else:
            return self.new(
                self._val.integrate(spaces),
                ContractionOperator(self._jac.target, spaces, 1)(self._jac))

    def exp(self):
        """Elementwise exponential; d(exp x) = exp(x) dx."""
        tmp = self._val.exp()
        return self.new(tmp, makeOp(tmp)(self._jac))

    def log(self):
        """Elementwise natural logarithm; d(log x) = dx / x."""
        tmp = self._val.log()
        return self.new(tmp, makeOp(1./self._val)(self._jac))

    def tanh(self):
        """Elementwise tanh; d(tanh x) = (1 - tanh(x)**2) dx."""
        tmp = self._val.tanh()
        return self.new(tmp, makeOp(1.-tmp**2)(self._jac))

    def positive_tanh(self):
        """Elementwise 0.5*(1 + tanh(x)), mapping into (0, 1)."""
        tmp = self._val.tanh()
        tmp2 = 0.5*(1.+tmp)
        return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))

    def add_metric(self, metric):
        """Return a copy of ``self`` with `metric` attached."""
        return self.new(self._val, self._jac, metric)

    def with_want_metric(self):
        """Return a copy of ``self`` with ``want_metric`` set to True."""
        return Linearization(self._val, self._jac, self._metric, True)

    @staticmethod
    def make_var(field, want_metric=False):
        """Linearization of `field` with the identity as Jacobian."""
        from .operators.scaling_operator import ScalingOperator
        return Linearization(field, ScalingOperator(1., field.domain),
                             want_metric=want_metric)

    @staticmethod
    def make_const(field, want_metric=False):
        """Linearization of `field` with a null Jacobian on its own domain."""
        from .operators.simple_linear_operators import NullOperator
        return Linearization(field, NullOperator(field.domain, field.domain),
                             want_metric=want_metric)

    @staticmethod
    def make_const_empty_input(field, want_metric=False):
        """Linearization of `field` with a null Jacobian whose domain is an
        empty MultiDomain."""
        from .operators.simple_linear_operators import NullOperator
        from .multi_domain import MultiDomain
        return Linearization(
            field, NullOperator(MultiDomain.make({}), field.domain),
            want_metric=want_metric)

    @staticmethod
    def make_partial_var(field, constants, want_metric=False):
        """Linearization of the MultiField `field` whose Jacobian is the
        identity on all keys except those in `constants`, which are scaled
        by zero. With empty `constants` this is ``make_var``."""
        from .operators.scaling_operator import ScalingOperator
        from .operators.block_diagonal_operator import BlockDiagonalOperator
        if len(constants) == 0:
            return Linearization.make_var(field, want_metric)
        else:
            ops = [ScalingOperator(0. if key in constants else 1., dom)
                   for key, dom in field.domain.items()]
            bdop = BlockDiagonalOperator(field.domain, tuple(ops))
            return Linearization(field, bdop, want_metric=want_metric)