from __future__ import absolute_import, division, print_function

from ..compat import *
from ..utilities import NiftyMetaBase, my_product
from ..domain_tuple import DomainTuple


class Operator(NiftyMetaBase()):
    """Transforms values living on one domain into values living on another
    domain, and can also provide the Jacobian.

    Operators compose with ``@`` / call syntax (chaining, see ``_OpChain``)
    and with ``*`` (product of the operands' outputs, see ``_OpProd``).
    Subclasses must provide ``domain``, ``target`` and ``apply``.
    """

    @property
    def domain(self):
        """DomainTuple or MultiDomain : the operator's input domain

        The domain on which the Operator's input Field lives.
        """
        raise NotImplementedError

    @property
    def target(self):
        """DomainTuple or MultiDomain : the operator's output domain

        The domain on which the Operator's output Field lives.
        """
        raise NotImplementedError

    def __matmul__(self, x):
        """Return the chained operator ``self @ x`` (``x`` is applied first).

        Non-Operator operands yield ``NotImplemented`` so Python can try the
        reflected operation of the other operand instead.
        """
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpChain.make((self, x))

    def __mul__(self, x):
        """Return an operator that multiplies the outputs of ``self`` and
        ``x`` evaluated on the same input.

        Non-Operator operands yield ``NotImplemented``.
        """
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpProd.make((self, x))

    def apply(self, x):
        """Evaluate the operator on ``x``; must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, x):
        """Chain with ``x`` if it is an Operator, otherwise evaluate on ``x``.

        ``op(other_op)`` is therefore equivalent to ``op @ other_op``.
        """
        if isinstance(x, Operator):
            return _OpChain.make((self, x))
        return self.apply(x)


def _make_function_method(funcname):
    """Build an Operator method that post-applies ``funcname`` to the output.

    The returned method chains a ``_FunctionApplier`` (which calls
    ``<output>.<funcname>()``) after ``self``. Binding ``funcname`` as a
    parameter of this factory gives every generated method its own value
    instead of sharing the loop variable.
    """
    def method(self):
        fa = _FunctionApplier(self.target, funcname)
        return _OpChain.make((fa, self))
    # Give the generated method a meaningful name for introspection/debugging.
    method.__name__ = funcname
    return method


# Underscore-prefixed loop variable so no stray public names leak from this
# module (the former ``f``/``func`` globals would be exported by ``import *``).
for _funcname in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
    setattr(Operator, _funcname, _make_function_method(_funcname))


class EnergyOperator(Operator):
    """Operator whose output always lives on the scalar domain.

    Subclasses provide ``domain`` and ``apply``; the target is fixed.
    """

    # One scalar domain (``DomainTuple.scalar_domain()``) created at class
    # definition time and shared by every EnergyOperator instance.
    _target = DomainTuple.scalar_domain()

    @property
    def target(self):
        """The shared scalar target domain, read from EnergyOperator itself."""
        return EnergyOperator._target

class _FunctionApplier(Operator):
    """Operator that calls a named, argument-free method on its input.

    With ``funcname="exp"``, ``apply(x)`` returns ``x.exp()``. Domain and
    target are identical.

    Parameters
    ----------
    domain : anything accepted by ``makeDomain``
        Domain (and target) of the operator.
    funcname : str
        Name of the method looked up on the input in :meth:`apply`.
    """

    def __init__(self, domain, funcname):
        from ..sugar import makeDomain
        self._domain = makeDomain(domain)
        self._funcname = funcname

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        # Target and domain coincide for this operator.
        return self._domain

    def apply(self, x):
        return getattr(x, self._funcname)()


class _CombinedOperator(Operator):
    """Base class for operators built from a flat tuple of sub-operators.

    Instances are meant to be created through :meth:`make`, which flattens
    nested operators of the same class and collapses a single operand to
    itself.
    """

    def __init__(self, ops, _callingfrommake=False):
        # Direct construction is rejected; make() is the intended entry point.
        if not _callingfrommake:
            raise NotImplementedError
        self._ops = tuple(ops)

    @classmethod
    def unpack(cls, ops, res):
        """Return a new list: ``res`` followed by ``ops``, with every nested
        ``cls`` instance recursively replaced by its sub-operators.

        ``res`` itself is not modified.
        """
        # Copy once and append in place instead of rebuilding the list for
        # every element (``res + [op]`` is quadratic in the operand count).
        res = list(res)
        for op in ops:
            if isinstance(op, cls):
                res = cls.unpack(op._ops, res)
            else:
                res.append(op)
        return res

    @classmethod
    def make(cls, ops):
        """Create a ``cls`` from ``ops``, flattening nested ``cls`` operands.

        If flattening leaves exactly one operator, that operator is returned
        unchanged instead of being wrapped.
        """
        res = cls.unpack(ops, [])
        if len(res) == 1:
            return res[0]
        return cls(res, _callingfrommake=True)


class _OpChain(_CombinedOperator):
    """Composition ``_ops[0] @ _ops[1] @ ... @ _ops[-1]`` of operators.

    The last operator is applied first, so the chain's domain is the domain
    of ``_ops[-1]`` and its target is the target of ``_ops[0]``.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpChain, self).__init__(ops, _callingfrommake)

    @property
    def domain(self):
        return self._ops[-1].domain

    @property
    def target(self):
        return self._ops[0].target

    def apply(self, x):
        # Feed the output of each operator into the next, innermost first.
        for op in reversed(self._ops):
            x = op(x)
        return x


class _OpProd(_CombinedOperator):
    """Product of operators: every sub-operator is evaluated on the same
    input and the results are multiplied together with ``my_product``.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpProd, self).__init__(ops, _callingfrommake)

    # NOTE(review): domain and target are taken from the first operand only;
    # nothing here checks that the remaining operands agree with it.
    @property
    def domain(self):
        return self._ops[0].domain

    @property
    def target(self):
        return self._ops[0].target

    def apply(self, x):
        return my_product([op(x) for op in self._ops])


class _OpSum(_CombinedOperator):
    """Combination of operators, presumably their sum (judging by the name).

    Work in progress: :meth:`apply` is not implemented yet. Domain and
    target are the ``domain_union`` of the sub-operators' domains/targets.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpSum, self).__init__(ops, _callingfrommake)
        # NOTE(review): ``domain_union`` is not imported anywhere in this
        # file as shown, so constructing an _OpSum raises NameError until an
        # import is added.
        self._domain = domain_union([op.domain for op in self._ops])
        self._target = domain_union([op.target for op in self._ops])

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    def apply(self, x):
        raise NotImplementedError


class SquaredNormOperator(EnergyOperator):
    """Energy operator returning ``x.vdot(x)`` on the scalar target domain.

    Parameters
    ----------
    domain :
        Stored as-is and reported as the operator's domain.
    """

    def __init__(self, domain):
        super(SquaredNormOperator, self).__init__()
        self._domain = domain

    @property
    def domain(self):
        return self._domain

    def apply(self, x):
        # NOTE(review): ``Field`` is not imported in this file as shown;
        # this raises NameError until the import is added.
        return Field(self._target, x.vdot(x))


class QuadraticFormOperator(EnergyOperator):
    """Energy operator computing ``0.5 * x.vdot(op(x))`` for a given
    EndomorphicOperator ``op``.

    Parameters
    ----------
    op : EndomorphicOperator
        Operator defining the quadratic form; also supplies the domain.

    Raises
    ------
    TypeError
        If ``op`` is not an EndomorphicOperator.
    """

    def __init__(self, op):
        from .endomorphic_operator import EndomorphicOperator
        super(QuadraticFormOperator, self).__init__()
        if not isinstance(op, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = op

    @property
    def domain(self):
        return self._op.domain

    def apply(self, x):
        """Return the energy; for Linearization input, return
        ``Linearization(energy, op(x))`` instead.
        """
        # NOTE(review): ``Field`` and ``Linearization`` are not imported in
        # this file as shown; both must be imported for this to run.
        # op(x) is needed for the value and, for Linearization input, as the
        # second Linearization argument, so it is computed only once.
        jac = self._op(x)
        val = Field(self._target, 0.5 * x.vdot(jac))
        if isinstance(x, Linearization):
            return Linearization(val, jac)
        return val