diagonal_operator.py 5.04 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import division
20
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
21 22 23
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
24
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
25
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
26

27 28

class DiagonalOperator(EndomorphicOperator):
    """ NIFTy class for diagonal operators.

    The NIFTy DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator
        (already containing volume factors).
    domain : tuple of DomainObjects
        The domain on which the Operator's input Field lives.
        If None, use the domain of "diagonal".
    spaces : tuple of int
        The elements of "domain" on which the operator acts.
        If None, it acts on all elements.

    Attributes
    ----------
    domain : DomainTuple
        The domain on which the Operator's input Field lives.

    Raises
    ------
    TypeError
        If `diagonal` is not a Field.
    ValueError
        If the domain of `diagonal` does not match `domain` (or, when
        `spaces` is given, the elements of `domain` selected by `spaces`).

    NOTE: the fields given to __init__ and returned from .diagonal() are
    considered to be non-bare, i.e. during operator application, no additional
    volume factors are applied!

    See Also
    --------
    EndomorphicOperator
    """

    def __init__(self, diagonal, domain=None, spaces=None):
        super(DiagonalOperator, self).__init__()

        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")
        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)
        if spaces is None:
            self._spaces = None
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
        else:
            self._spaces = utilities.parse_spaces(spaces, len(self._domain))
            if len(self._spaces) != len(diagonal.domain):
                raise ValueError("spaces and domain must have the same length")
            # if nspc==len(self.diagonal.domain),
            # we could do some optimization
            for i, j in enumerate(self._spaces):
                if diagonal.domain[i] != self._domain[j]:
                    raise ValueError("domain mismatch")
            if self._spaces == tuple(range(len(self._domain))):
                self._spaces = None  # shortcut: acts on the whole domain

        self._diagonal = diagonal.copy()

        if self._spaces is not None:
            # Array axes of the full domain that belong to the selected
            # spaces; all other axes get length 1 so that the diagonal
            # broadcasts against the input's values in apply().
            active_axes = []
            for space_index in self._spaces:
                active_axes += self._domain.axes[space_index]

            # When the diagonal starts at space 0, its local chunk lines up
            # with the local chunk of the input (presumably dobj distributes
            # along axis 0 -- see local_shape(..., 0) below); otherwise the
            # complete diagonal array is needed. An empty `spaces` tuple
            # (scalar diagonal) must take the global path, not index [0].
            if self._spaces and self._spaces[0] == 0:
                self._ldiag = dobj.local_data(self._diagonal.val)
            else:
                self._ldiag = dobj.to_global_data(self._diagonal.val)
            locshape = dobj.local_shape(self._domain.shape, 0)
            self._reshaper = [shp if i in active_axes else 1
                              for i, shp in enumerate(locshape)]
            self._ldiag = self._ldiag.reshape(self._reshaper)
        else:
            self._ldiag = dobj.local_data(self._diagonal.val)

    def apply(self, x, mode):
        self._check_input(x, mode)

        # Decompose the mode into "adjoint?" and "inverse?" flags; the four
        # supported modes are TIMES, ADJOINT_TIMES, INVERSE_TIMES and the
        # remaining adjoint-inverse mode (validated by _check_input).
        adjoint = mode not in (self.TIMES, self.INVERSE_TIMES)
        inverse = mode not in (self.TIMES, self.ADJOINT_TIMES)

        diag = self._ldiag
        # Adjoint modes use the complex conjugate; it is skipped for real
        # floating-point diagonals, where it would be a no-op.
        if adjoint and not np.issubdtype(diag.dtype, np.floating):
            diag = diag.conj()
        if inverse:
            return Field(x.domain, val=x.val/diag)
        return Field(x.domain, val=x.val*diag)

    def diagonal(self):
        """ Returns the diagonal of the Operator."""
        return self._diagonal.copy()

    @property
    def domain(self):
        return self._domain

    @property
    def capability(self):
        return self._all_ops

    @property
    def inverse(self):
        """DiagonalOperator with the element-wise reciprocal diagonal."""
        return DiagonalOperator(1./self._diagonal, self._domain, self._spaces)

    @property
    def adjoint(self):
        """DiagonalOperator with the complex-conjugated diagonal."""
        return DiagonalOperator(self._diagonal.conjugate(), self._domain,
                                self._spaces)