operator.py 3.79 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2 3
from __future__ import absolute_import, division, print_function

from ..compat import *
Martin Reinecke's avatar
tmp  
Martin Reinecke committed
4 5
from ..utilities import NiftyMetaBase, my_product
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
6 7 8 9 10 11 12


class Operator(NiftyMetaBase()):
    """Transforms values living on one domain into values living on another
    domain, and can also provide the Jacobian.

    Subclasses set ``_domain`` and ``_target`` and override :meth:`apply`.
    Operators are composed with ``@`` (or by calling one operator on
    another) and multiplied with ``*``.
    """

    @property
    def domain(self):
        """DomainTuple or MultiDomain : input domain of the operator

            The domain on which the Field fed into the Operator lives."""
        return self._domain

    @property
    def target(self):
        """DomainTuple or MultiDomain : output domain of the operator

            The domain on which the Field produced by the Operator lives."""
        return self._target

    def scale(self, factor):
        """Return an operator whose output is this operator's output times
        `factor`.

        A factor equal to 1 returns `self` unchanged; otherwise a
        ScalingOperator on `self.target` is applied to this operator.
        """
        if factor == 1:
            return self
        from .scaling_operator import ScalingOperator
        scaler = ScalingOperator(factor, self.target)
        return scaler(self)

    def conjugate(self):
        """Return an operator that applies a ConjugationOperator on
        `self.target` to this operator's output."""
        from .simple_linear_operators import ConjugationOperator
        conj = ConjugationOperator(self.target)
        return conj(self)

    @property
    def real(self):
        """Operator that applies a Realizer on `self.target` to this
        operator's output."""
        from .simple_linear_operators import Realizer
        realizer = Realizer(self.target)
        return realizer(self)

    def __neg__(self):
        # Negation is scaling by -1.
        return self.scale(-1)

    def __matmul__(self, x):
        """Compose operators: ``(self @ x)(v)`` evaluates ``self(x(v))``.

        Returns NotImplemented for non-Operator operands so Python can try
        the reflected operation.
        """
        if isinstance(x, Operator):
            return _OpChain.make((self, x))
        return NotImplemented

    def __mul__(self, x):
        """Multiply operators: the result applies both to the same input and
        multiplies the outputs.

        Returns NotImplemented for non-Operator operands.
        """
        if isinstance(x, Operator):
            return _OpProd.make((self, x))
        return NotImplemented

    def apply(self, x):
        """Evaluate the operator on `x`; must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, x):
        """Apply the operator to `x`, or compose with `x` if it is itself an
        Operator (same as ``self @ x``)."""
        if not isinstance(x, Operator):
            return self.apply(x)
        return _OpChain.make((self, x))
Martin Reinecke's avatar
Martin Reinecke committed
62 63


Martin Reinecke's avatar
Martin Reinecke committed
64 65 66
def _make_function_method(funcname):
    """Create an `Operator` method that post-applies the method `funcname`.

    Calling the returned method on an operator `op` chains a
    `_FunctionApplier` (which calls ``x.<funcname>()`` on its input) after
    `op`, so ``op.<funcname>()(x)`` evaluates ``op(x).<funcname>()``.

    Parameters
    ----------
    funcname : str
        Name of the zero-argument method to call on the operator's output.

    Returns
    -------
    function
        An unbound method suitable for attaching to `Operator`.
    """
    def method(self):
        fa = _FunctionApplier(self.target, funcname)
        return _OpChain.make((fa, self))
    # Give each generated method its own name instead of a shared helper
    # name, so tracebacks, help() and repr() identify it correctly.
    method.__name__ = funcname
    method.__qualname__ = "Operator." + funcname
    method.__doc__ = ("Return an operator that applies ``{}()`` to the "
                      "output of this operator.".format(funcname))
    return method


for _funcname in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
    setattr(Operator, _funcname, _make_function_method(_funcname))
# Do not leak the loop variable into the module namespace (star-imports).
del _funcname


class _FunctionApplier(Operator):
    """Operator that calls the zero-argument method named `funcname` on its
    input and returns the result.

    Its domain and target are the same object, built with `makeDomain`.
    """

    def __init__(self, domain, funcname):
        from ..sugar import makeDomain
        dom = makeDomain(domain)
        self._domain = dom
        self._target = dom
        self._funcname = funcname

    def apply(self, x):
        # Look up the method on the input at call time, e.g. ``x.sqrt``.
        method = getattr(x, self._funcname)
        return method()


Martin Reinecke's avatar
Martin Reinecke committed
83 84 85 86 87 88 89 90 91 92
class _CombinedOperator(Operator):
    """Common base for operators built from a flat tuple of sub-operators.

    Instances must be created through :meth:`make`; constructing one
    directly raises NotImplementedError.
    """

    def __init__(self, ops, _callingfrommake=False):
        if _callingfrommake:
            self._ops = tuple(ops)
        else:
            raise NotImplementedError

    @classmethod
    def unpack(cls, ops, res):
        """Extend the list `res` with `ops`, flattening nested `cls` instances.

        The result is built by list concatenation, so the `res` passed in
        is never mutated.
        """
        for op in ops:
            if isinstance(op, cls):
                # Splice in the children of a same-kind combination.
                res = cls.unpack(op._ops, res)
                continue
            res = res + [op]
        return res

    @classmethod
    def make(cls, ops):
        """Build a `cls` from `ops`, or return the sole operator if exactly
        one remains after flattening."""
        flat = cls.unpack(ops, [])
        if len(flat) != 1:
            return cls(flat, _callingfrommake=True)
        return flat[0]


class _OpChain(_CombinedOperator):
    """Composition of operators; ``_ops[-1]`` is applied first and
    ``_ops[0]`` last.

    The chain's domain is that of the last sub-operator and its target is
    that of the first.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpChain, self).__init__(ops, _callingfrommake)
        innermost = self._ops[-1]
        self._domain = innermost.domain
        outermost = self._ops[0]
        self._target = outermost.target

    def apply(self, x):
        result = x
        for op in self._ops[::-1]:
            result = op(result)
        return result


class _OpProd(_CombinedOperator):
    """Product of the outputs of all sub-operators, each applied to the same
    input.

    Domain and target are taken from the first sub-operator; the others are
    not checked against it here.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpProd, self).__init__(ops, _callingfrommake)
        lead = self._ops[0]
        self._domain = lead.domain
        self._target = lead.target

    def apply(self, x):
        return my_product(op(x) for op in self._ops)
Martin Reinecke's avatar
Martin Reinecke committed
126 127 128 129 130 131 132 133


class _OpSum(_CombinedOperator):
    """Sum of several operators (work in progress: `apply` is not
    implemented yet).

    Domain and target are the unions of the sub-operators' domains and
    targets, as computed by `domain_union`.
    """

    def __init__(self, ops, _callingfrommake=False):
        super(_OpSum, self).__init__(ops, _callingfrommake)
        # `domain_union` is neither defined nor imported in this module, so
        # this constructor used to fail with NameError. Import it locally,
        # like `_FunctionApplier` does with `makeDomain`.
        # NOTE(review): assumes `..sugar` provides `domain_union` — confirm.
        from ..sugar import domain_union
        self._domain = domain_union([op.domain for op in self._ops])
        self._target = domain_union([op.target for op in self._ops])

    def apply(self, x):
        raise NotImplementedError