operator.py 8.11 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
from __future__ import absolute_import, division, print_function

Martin Reinecke's avatar
Martin Reinecke committed
3
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
4
from ..compat import *
Philipp Arras's avatar
Philipp Arras committed
5
from ..utilities import NiftyMetaBase, indent
Martin Reinecke's avatar
Martin Reinecke committed
6
7
8
9
10
11
12


class Operator(NiftyMetaBase()):
    """Transforms values living on one domain into values living on another
    domain, and can also provide the Jacobian.
    """

    @property
    def domain(self):
        """DomainTuple or MultiDomain : the operator's input domain

            The domain on which the Operator's input Field lives."""
        return self._domain

    @property
    def target(self):
        """DomainTuple or MultiDomain : the operator's output domain

            The domain on which the Operator's output Field lives."""
        return self._target

    @staticmethod
    def _check_domain_equality(dom_op, dom_field):
        """Raise ValueError if `dom_op` and `dom_field` differ.

        If `dom_op` is neither a `DomainTuple` nor a `MultiDomain`, an extra
        hint about that is appended to the error message.
        """
        if dom_op != dom_field:
            s = "The operator's and field's domains don't match."
            # Imported lazily, presumably to avoid circular imports.
            from ..domain_tuple import DomainTuple
            from ..multi_domain import MultiDomain
            if not isinstance(dom_op, (DomainTuple, MultiDomain,)):
                s += " Your operator's domain is neither a `DomainTuple`" \
                     " nor a `MultiDomain`."
            raise ValueError(s)

    def scale(self, factor):
        """Return this operator followed by a `ScalingOperator(factor)`.

        If `factor == 1`, `self` is returned unchanged.
        """
        if factor == 1:
            return self
        from .scaling_operator import ScalingOperator
        return ScalingOperator(factor, self.target)(self)

    def conjugate(self):
        """Return this operator followed by a `ConjugationOperator`."""
        from .simple_linear_operators import ConjugationOperator
        return ConjugationOperator(self.target)(self)

    @property
    def real(self):
        """Operator : this operator followed by a `Realizer`."""
        from .simple_linear_operators import Realizer
        return Realizer(self.target)(self)

    def __neg__(self):
        """Return this operator scaled by -1."""
        return self.scale(-1)

    def __matmul__(self, x):
        """Compose operators: `(self @ x)(y)` applies `x` first, then `self`.

        Returns NotImplemented if `x` is not an Operator.
        """
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpChain.make((self, x))

    def __mul__(self, x):
        """Return an operator computing the product of both outputs.

        Returns NotImplemented if `x` is not an Operator.
        """
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpProd(self, x)

    def __add__(self, x):
        """Return an operator computing the sum of both outputs.

        Returns NotImplemented if `x` is not an Operator.
        """
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpSum(self, x)

    def __sub__(self, x):
        """Return an operator computing `self - x`, built as `self + (-x)`.

        Returns NotImplemented if `x` is not an Operator.
        """
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpSum(self, -x)

    def __pow__(self, power):
        """Return this operator followed by raising its output to `power`.

        Returns NotImplemented unless `np.isscalar(power)` holds.
        """
        if not np.isscalar(power):
            return NotImplemented
        return _OpChain.make((_PowerOp(self.target, power), self))

    def clip(self, min=None, max=None):
        """Return this operator followed by clipping of its output.

        Parameters
        ----------
        min, max : scalar or None
            Bounds forwarded to the output's ``clip`` method. If both are
            None, no clipping operator is added and `self` is returned.
        """
        if min is None and max is None:
            return self
        # Bug fix: this line previously referenced the undefined name `sef`,
        # so any call with a bound raised NameError.
        return _OpChain.make((_Clipper(self.target, min, max), self))

    def apply(self, x):
        """Apply the operator to `x`; must be implemented by subclasses."""
        raise NotImplementedError

    def force(self, x):
        """Extract correct subset of domain of x and apply operator."""
        return self.apply(x.extract(self.domain))

    def _check_input(self, x):
        """Raise ValueError unless `x` lives on this operator's domain.

        For a `Linearization`, its `target` is compared; otherwise
        `x.domain` is compared.
        """
        from ..linearization import Linearization
        d = x.target if isinstance(x, Linearization) else x.domain
        self._check_domain_equality(self._domain, d)

    def __call__(self, x):
        """Compose with `x` if it is an Operator, otherwise apply to `x`."""
        if isinstance(x, Operator):
            return _OpChain.make((self, x))
        return self.apply(x)

    def ducktape(self, name):
        """Return this operator preceded by `ducktape(self, None, name)`.

        NOTE(review): presumably this adapts a MultiDomain input entry
        called `name` to `self.domain`; confirm against
        `simple_linear_operators.ducktape`.
        """
        from .simple_linear_operators import ducktape
        return self(ducktape(self, None, name))

    def ducktape_left(self, name):
        """Return this operator followed by `ducktape(None, self, name)`.

        NOTE(review): presumably this wraps the output into a MultiDomain
        entry called `name`; confirm against
        `simple_linear_operators.ducktape`.
        """
        from .simple_linear_operators import ducktape
        return ducktape(None, self, name)(self)

    def __repr__(self):
        return self.__class__.__name__

Martin Reinecke's avatar
Martin Reinecke committed
114

115
# Attach one method per name below to Operator. Calling e.g. ``op.exp()``
# returns the chain "_FunctionApplier('exp') after op", which calls the
# method of that name on op's output. The factory `func` receives the name
# as an argument, so each generated method captures its own name instead of
# the loop variable's final value (late-binding closure pitfall).
for f in ("sqrt", "exp", "log", "tanh", "sigmoid",
          "clipped_exp", "sin", "cos", "tan",
          "sinh", "cosh", "absolute", "sinc", "one_over"):
    def func(funcname):
        def func2(self):
            return _OpChain.make(
                (_FunctionApplier(self.target, funcname), self))
        return func2
    setattr(Operator, f, func(f))


class _FunctionApplier(Operator):
    """Operator that calls a named, argument-free method on its input.

    `apply(x)` returns ``getattr(x, funcname)()``. Domain and target are
    the same.
    """

    def __init__(self, domain, funcname):
        from ..sugar import makeDomain
        dom = makeDomain(domain)
        self._domain = dom
        self._target = dom
        self._funcname = funcname

    def apply(self, x):
        self._check_input(x)
        method = getattr(x, self._funcname)
        return method()


Martin Reinecke's avatar
Martin Reinecke committed
137
138
139
140
141
142
143
144
145
146
147
148
class _Clipper(Operator):
    """Operator that forwards its input to ``x.clip(min, max)``.

    Domain and target are the same. The bounds are stored as given (either
    may be None) and passed through unchanged.
    """

    def __init__(self, domain, min=None, max=None):
        from ..sugar import makeDomain
        dom = makeDomain(domain)
        self._domain = dom
        self._target = dom
        self._min, self._max = min, max

    def apply(self, x):
        self._check_input(x)
        lo, hi = self._min, self._max
        return x.clip(lo, hi)


Martin Reinecke's avatar
Martin Reinecke committed
149
150
151
152
153
154
155
156
157
158
159
class _PowerOp(Operator):
    """Operator returning ``x**power`` for its input `x`.

    Domain and target are the same.
    """

    def __init__(self, domain, power):
        from ..sugar import makeDomain
        dom = makeDomain(domain)
        self._domain = dom
        self._target = dom
        self._power = power

    def apply(self, x):
        self._check_input(x)
        exponent = self._power
        return x**exponent


Martin Reinecke's avatar
Martin Reinecke committed
160
161
162
163
164
165
166
167
168
169
class _CombinedOperator(Operator):
    """Base for operators built from a flat tuple of sub-operators.

    Instances are meant to be created through `make`, which flattens nested
    instances of the same class. The constructor raises NotImplementedError
    unless `_callingfrommake` is true.
    """

    def __init__(self, ops, _callingfrommake=False):
        if _callingfrommake:
            self._ops = tuple(ops)
            return
        raise NotImplementedError

    @classmethod
    def unpack(cls, ops, res):
        """Return `res` extended by `ops`, recursively inlining the `_ops`
        of every element that is itself an instance of `cls`.

        New lists are built by concatenation, so the list passed as `res`
        is never modified in place.
        """
        for op in ops:
            if not isinstance(op, cls):
                res = res + [op]
                continue
            res = cls.unpack(op._ops, res)
        return res

    @classmethod
    def make(cls, ops):
        """Build a flattened instance of `cls` from `ops`.

        If flattening yields exactly one operator, it is returned as-is
        rather than wrapped.
        """
        flat = cls.unpack(ops, [])
        if len(flat) != 1:
            return cls(flat, _callingfrommake=True)
        return flat[0]


class _OpChain(_CombinedOperator):
    """Composition of operators, where ``_ops[-1]`` is applied first and
    ``_ops[0]`` last.

    The chain's domain is the domain of the last element, and its target is
    the target of the first element.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpChain, self).__init__(ops, _callingfrommake)
        chain = self._ops
        self._domain = chain[-1].domain
        self._target = chain[0].target
        # Each operator must accept what its successor in the tuple (which
        # runs before it) produces.
        for outer, inner in zip(chain[:-1], chain[1:]):
            if outer.domain != inner.target:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        res = x
        for op in self._ops[::-1]:
            res = op(res)
        return res

    def __repr__(self):
        body = "\n".join(op.__repr__() for op in self._ops)
        return "_OpChain:\n" + indent(body)


Martin Reinecke's avatar
Martin Reinecke committed
203
204
205
206
207
208
209
210
211
class _OpProd(Operator):
    """Product of the outputs of two operators.

    The domain is the union of both input domains. Both operators must have
    the same target, which is also this operator's target.
    """

    def __init__(self, op1, op2):
        from ..sugar import domain_union
        self._domain = domain_union((op1.domain, op2.domain))
        if op1.target != op2.target:
            raise ValueError("target mismatch")
        self._target = op1.target
        self._op1, self._op2 = op1, op2

    def apply(self, x):
        from ..linearization import Linearization
        from ..sugar import makeOp
        self._check_input(x)
        is_lin = isinstance(x, Linearization)
        val = x._val if is_lin else x
        # Each factor only sees the part of the input on its own domain.
        val1 = val.extract(self._op1.domain)
        val2 = val.extract(self._op2.domain)
        if not is_lin:
            return self._op1(val1) * self._op2(val2)
        want_metric = x.want_metric
        lin1 = self._op1(Linearization.make_var(val1, want_metric))
        lin2 = self._op2(Linearization.make_var(val2, want_metric))
        # Product rule: makeOp(f) J_g + makeOp(g) J_f, then chained with the
        # incoming Jacobian x.jac.
        jac = makeOp(lin1._val)(lin2._jac)._myadd(
            makeOp(lin2._val)(lin1._jac), False)
        return lin1.new(lin1._val*lin2._val, jac(x.jac))

    def __repr__(self):
        body = "\n".join(op.__repr__() for op in (self._op1, self._op2))
        return "_OpProd:\n"+indent(body)


Martin Reinecke's avatar
Martin Reinecke committed
235
236
class _OpSum(Operator):
    """Sum of the outputs of two operators, combined via their ``unite``
    method.

    Domain and target are the unions of the two operators' domains and
    targets, respectively.
    """

    def __init__(self, op1, op2):
        from ..sugar import domain_union
        self._domain = domain_union((op1.domain, op2.domain))
        self._target = domain_union((op1.target, op2.target))
        self._op1, self._op2 = op1, op2

    def apply(self, x):
        from ..linearization import Linearization
        self._check_input(x)
        is_lin = isinstance(x, Linearization)
        val = x._val if is_lin else x
        # Each summand only sees the part of the input on its own domain.
        val1 = val.extract(self._op1.domain)
        val2 = val.extract(self._op2.domain)
        if not is_lin:
            return self._op1(val1).unite(self._op2(val2))
        want_metric = x.want_metric
        lin1 = self._op1(Linearization.make_var(val1, want_metric))
        lin2 = self._op2(Linearization.make_var(val2, want_metric))
        jac = lin1._jac._myadd(lin2._jac, False)
        res = lin1.new(lin1._val.unite(lin2._val), jac(x.jac))
        # A metric is attached only when both summands provide one.
        if lin1._metric is None or lin2._metric is None:
            return res
        return res.add_metric(lin1._metric + lin2._metric)

    def __repr__(self):
        body = "\n".join(op.__repr__() for op in (self._op1, self._op2))
        return "_OpSum:\n"+indent(body)