operator.py 6.24 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2 3
from __future__ import absolute_import, division, print_function

from ..compat import *
Philipp Arras's avatar
Philipp Arras committed
4
from ..utilities import NiftyMetaBase
Martin Reinecke's avatar
Martin Reinecke committed
5 6 7 8 9 10 11


class Operator(NiftyMetaBase()):
    """Transforms values living on one domain into values living on another
    domain, and can also provide the Jacobian.

    Subclasses are expected to set `_domain` and `_target` and to implement
    `apply`.
    """

    @property
    def domain(self):
        """DomainTuple or MultiDomain : the operator's input domain

            The domain on which the Operator's input Field lives."""
        return self._domain

    @property
    def target(self):
        """DomainTuple or MultiDomain : the operator's output domain

            The domain on which the Operator's output Field lives."""
        return self._target

    def scale(self, factor):
        """Return this operator followed by a `ScalingOperator(factor)`.

        If `factor == 1`, `self` is returned unchanged.
        """
        if factor == 1:
            return self
        from .scaling_operator import ScalingOperator
        return ScalingOperator(factor, self.target)(self)

    def conjugate(self):
        """Return this operator followed by a `ConjugationOperator` on its
        target."""
        from .simple_linear_operators import ConjugationOperator
        return ConjugationOperator(self.target)(self)

    @property
    def real(self):
        """Operator : this operator followed by a `Realizer` on its target."""
        from .simple_linear_operators import Realizer
        return Realizer(self.target)(self)

    def __neg__(self):
        """Return `self.scale(-1)`."""
        return self.scale(-1)

    def __matmul__(self, x):
        """Compose operators: `(self @ x)(v)` evaluates `self(x(v))`.

        Returns NotImplemented for non-Operator operands so that Python can
        try the reflected operation.
        """
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpChain.make((self, x))

    def __mul__(self, x):
        """Return the pointwise product of the outputs of `self` and `x`
        (an `_OpProd`)."""
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpProd(self, x)

    def __add__(self, x):
        """Return the sum of the outputs of `self` and `x` (an `_OpSum`)."""
        if not isinstance(x, Operator):
            return NotImplemented
        return _OpSum(self, x)

    def apply(self, x):
        """Apply the operator to `x`. Must be implemented by subclasses."""
        raise NotImplementedError

    def force(self, x):
        """Extract correct subset of domain of x and apply operator."""
        return self.apply(x.extract(self.domain))

    def _check_input(self, x):
        """Raise ValueError unless the input's domain equals `self.domain`.

        For a Linearization the relevant domain is `x.target`; otherwise it
        is `x.domain`.
        """
        from ..linearization import Linearization
        d = x.target if isinstance(x, Linearization) else x.domain
        if self._domain == d:
            return
        s = "The operator's and field's domains don't match."
        from ..domain_tuple import DomainTuple
        from ..multi_domain import MultiDomain
        # isinstance needs a tuple of types, not a list. The hint is added
        # only when at least one side is not a proper domain object.
        valid = (DomainTuple, MultiDomain)
        if not (isinstance(self._domain, valid) and isinstance(d, valid)):
            s += " One of the domains is neither a `DomainTuple` nor a `MultiDomain`."
        raise ValueError(s)

    def __call__(self, x):
        """Compose with an Operator, or apply to any other input."""
        if isinstance(x, Operator):
            return _OpChain.make((self, x))
        return self.apply(x)

    def __repr__(self):
        return self.__class__.__name__


# Give every Operator convenience methods (op.sqrt(), op.exp(), ...).
# Each one chains a _FunctionApplier for the named pointwise function after
# the operator itself.
for f in ("sqrt", "exp", "log", "tanh", "positive_tanh"):
    def func(funcname):
        # The factory gives each generated method its own `funcname`. A
        # closure over the loop variable would see only its last value.
        def _method(self):
            applier = _FunctionApplier(self.target, funcname)
            # Chains list the outermost operator first, so `applier` runs
            # after `self`.
            return _OpChain.make((applier, self))
        return _method
    setattr(Operator, f, func(f))


class _FunctionApplier(Operator):
    """Operator that calls the zero-argument method named `funcname` on its
    input (e.g. `x.exp()`). Domain and target are the same object."""

    def __init__(self, domain, funcname):
        from ..sugar import makeDomain
        dom = makeDomain(domain)
        self._domain = dom
        self._target = dom
        self._funcname = funcname

    def apply(self, x):
        self._check_input(x)
        method = getattr(x, self._funcname)
        return method()


class _CombinedOperator(Operator):
    """Base for operators built from a flat tuple of sub-operators.

    Instances must be created through `make`, which flattens nested
    instances of the same class.
    """

    def __init__(self, ops, _callingfrommake=False):
        # Direct construction is refused; only `make` passes the flag.
        if _callingfrommake:
            self._ops = tuple(ops)
        else:
            raise NotImplementedError

    @classmethod
    def unpack(cls, ops, res):
        """Return a new list: `res` followed by `ops`, with nested `cls`
        instances recursively replaced by their sub-operators."""
        for op in ops:
            res = cls.unpack(op._ops, res) if isinstance(op, cls) else res + [op]
        return res

    @classmethod
    def make(cls, ops):
        """Flatten `ops`; return the single operator when only one remains,
        otherwise a new `cls` instance wrapping them."""
        flat = cls.unpack(ops, [])
        return flat[0] if len(flat) == 1 else cls(flat, _callingfrommake=True)


class _OpChain(_CombinedOperator):
    """Composition of operators: `_ops[-1]` is applied first and `_ops[0]`
    last."""

    def __init__(self, ops, _callingfrommake=False):
        super(_OpChain, self).__init__(ops, _callingfrommake)
        chain = self._ops
        self._domain = chain[-1].domain
        self._target = chain[0].target
        # Each operator's input domain must equal the output (target) of the
        # operator to its right, which is applied just before it.
        for outer, inner in zip(chain[:-1], chain[1:]):
            if outer.domain != inner.target:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        result = x
        for op in self._ops[::-1]:
            result = op(result)
        return result


class _OpProd(Operator):
    """Pointwise product of the outputs of two operators with equal targets.

    The input domain is the union of both operators' domains. Each operator
    receives only the part of the input lying on its own domain.
    """

    def __init__(self, op1, op2):
        from ..sugar import domain_union
        self._domain = domain_union((op1.domain, op2.domain))
        self._target = op1.target
        if op1.target != op2.target:
            raise ValueError("target mismatch")
        self._op1, self._op2 = op1, op2

    def apply(self, x):
        from ..linearization import Linearization
        from ..sugar import makeOp
        self._check_input(x)
        is_lin = isinstance(x, Linearization)
        val = x._val if is_lin else x
        part1 = val.extract(self._op1.domain)
        part2 = val.extract(self._op2.domain)
        if not is_lin:
            return self._op1(part1) * self._op2(part2)

        want_metric = x.want_metric
        res1 = self._op1(Linearization.make_var(part1, want_metric))
        res2 = self._op2(Linearization.make_var(part2, want_metric))
        # Product rule d(f*g) = f*dg + g*df. makeOp(value) is assumed to be
        # the pointwise (diagonal) multiplication by `value`; confirm in sugar.
        term1 = makeOp(res1._val)(res2._jac)
        term2 = makeOp(res2._val)(res1._jac)
        jac = term1._myadd(term2, False)
        return res1.new(res1._val*res2._val, jac(x.jac))


class _OpSum(Operator):
    """Sum of the outputs of two operators.

    Domain and target are the unions of the two operators' domains and
    targets. Each operator receives only the part of the input lying on its
    own domain, and the two outputs are merged with `unite`.
    """

    def __init__(self, op1, op2):
        from ..sugar import domain_union
        self._domain = domain_union((op1.domain, op2.domain))
        self._target = domain_union((op1.target, op2.target))
        self._op1 = op1
        self._op2 = op2

    def apply(self, x):
        """Apply to a field, or to a Linearization (value, Jacobian and, when
        both summands provide one, metric)."""
        from ..linearization import Linearization
        self._check_input(x)
        lin = isinstance(x, Linearization)
        v = x._val if lin else x
        v1 = v.extract(self._op1.domain)
        v2 = v.extract(self._op2.domain)
        if not lin:
            return self._op1(v1).unite(self._op2(v2))
        wm = x.want_metric
        lin1 = self._op1(Linearization.make_var(v1, wm))
        lin2 = self._op2(Linearization.make_var(v2, wm))
        op = lin1._jac._myadd(lin2._jac, False)
        # Merge the values with `unite`, as the non-linear branch does. The
        # target is a union of the two targets, so plain `+` presumably
        # fails when they differ.
        res = lin1.new(lin1._val.unite(lin2._val), op(x.jac))
        if lin1._metric is not None and lin2._metric is not None:
            # Combine the metrics the same way as the Jacobians above.
            res = res.add_metric(lin1._metric._myadd(lin2._metric, False))
        return res