# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from nifty.operators.linear_operator import LinearOperator


class ComposedOperator(LinearOperator):
    """ NIFTY class for composed operators.

    The NIFTY composed operator class combines multiple linear operators.
    The i-th operator acts on the i-th consecutive block of spaces, where
    the block sizes are given by the lengths of the operators' domains
    (forward application) or targets (inverse/adjoint application).

    Parameters
    ----------
    operators : tuple of NIFTy Operators
        The tuple of LinearOperators.
    default_spaces : tuple of ints *optional*
        Defines on which space(s) of a given field the Operator acts by
        default (default: None)

    Attributes
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The concatenation of the domains of all stored operators, in order.
    target : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The concatenation of the targets of all stored operators, in order.
    unitary : boolean
        Always False for a composed operator.

    Raises
    ------
    TypeError
        Raised if
            * an element of the operator list is not an instance of the
              LinearOperator-baseclass.

    Notes
    -----
    Very useful in case one has to transform a Field living over a product
    space (see example below).

    Examples
    --------
    Minimal example of transforming a Field living on two domains into its
    harmonic space.

    >>> x1 = RGSpace(5)
    >>> x2 = RGSpace(10)
    >>> k1 = RGRGTransformation.get_codomain(x1)
    >>> k2 = RGRGTransformation.get_codomain(x2)
    >>> FFT1 = FFTOperator(domain=x1, target=k1,
    ...                    domain_dtype=np.float64, target_dtype=np.complex128)
    >>> FFT2 = FFTOperator(domain=x2, target=k2,
    ...                    domain_dtype=np.float64, target_dtype=np.complex128)
    >>> FFT = ComposedOperator((FFT1, FFT2))
    >>> f = Field.from_random('normal', domain=(x1, x2))
    >>> FFT.times(f)

    See Also
    --------
    EndomorphicOperator, ProjectionOperator,
    DiagonalOperator, SmoothingOperator, ResponseOperator,
    PropagatorOperator, ComposedOperator

    """

    # ---Overwritten properties and methods---
    def __init__(self, operators, default_spaces=None):
        super(ComposedOperator, self).__init__(default_spaces)

        # Materialize once so that generators/iterators are consumed a
        # single time and the store is an immutable tuple.
        operators = tuple(operators)
        for op in operators:
            if not isinstance(op, LinearOperator):
                raise TypeError("The elements of the operator list must be "
                                "instances of the LinearOperator-baseclass")
        self._operator_store = operators

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """
        The input check must be disabled for the ComposedOperator, since it
        is not easily foreseeable what the output of an operator-call
        will look like. Only the default spaces are substituted when
        `spaces` is None.
        """
        if spaces is None:
            spaces = self.default_spaces
        return spaces

    # ---Mandatory properties and methods---
    @property
    def domain(self):
        # Lazily computed and cached: concatenation of all operator domains.
        if not hasattr(self, '_domain'):
            self._domain = ()
            for op in self._operator_store:
                self._domain += op.domain
        return self._domain

    @property
    def target(self):
        # Lazily computed and cached: concatenation of all operator targets.
        if not hasattr(self, '_target'):
            self._target = ()
            for op in self._operator_store:
                self._target += op.target
        return self._target

    @property
    def unitary(self):
        return False

    def _times(self, x, spaces):
        return self._times_helper(x, spaces, func='times')

    def _adjoint_times(self, x, spaces):
        return self._inverse_times_helper(x, spaces, func='adjoint_times')

    def _inverse_times(self, x, spaces):
        return self._inverse_times_helper(x, spaces, func='inverse_times')

    def _adjoint_inverse_times(self, x, spaces):
        return self._times_helper(x, spaces, func='adjoint_inverse_times')

    @staticmethod
    def _split_spaces(spaces, lengths):
        """Split `spaces` into consecutive chunks of the given `lengths`.

        Chunk i holds the space indices assigned to the i-th operator.
        """
        chunks = []
        start = 0
        for length in lengths:
            chunks.append(spaces[start:start + length])
            start += length
        return chunks

    def _times_helper(self, x, spaces, func):
        """Apply `func` of every operator, first to last.

        Each operator receives its own consecutive chunk of `spaces`, sized
        by the length of its domain.
        """
        if spaces is None:
            spaces = range(len(self.domain))
        lengths = [len(op.domain) for op in self._operator_store]
        chunks = self._split_spaces(spaces, lengths)
        for op, active_spaces in zip(self._operator_store, chunks):
            x = getattr(op, func)(x, spaces=active_spaces)
        return x

    def _inverse_times_helper(self, x, spaces, func):
        """Apply `func` of every operator, last to first.

        The space chunks are assigned in forward order (sized by each
        operator's target), exactly as in `_times_helper`, so every operator
        acts on the same spaces in both directions and receives them in
        ascending order. Only the order of application is reversed.
        """
        if spaces is None:
            spaces = range(len(self.target))
        lengths = [len(op.target) for op in self._operator_store]
        chunks = self._split_spaces(spaces, lengths)
        pairs = list(zip(self._operator_store, chunks))
        for op, active_spaces in reversed(pairs):
            x = getattr(op, func)(x, spaces=active_spaces)
        return x