diagonal_operator.py 5.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
18

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import division
20
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
21
22
23
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
24
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
25
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
26

27
28

class DiagonalOperator(EndomorphicOperator):
    """NIFTY class for diagonal operators.

    A DiagonalOperator is an EndomorphicOperator whose action on an input
    field is a pixel-wise multiplication with its diagonal (pixel-wise
    division for the inverse application modes).

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator (already containing volume
        factors).
    domain : tuple of DomainObjects, optional
        The domain on which the Operator's input Field lives.
        If None, the domain of `diagonal` is used.
    spaces : tuple of int, optional
        The elements of `domain` on which the operator acts.
        If None, it acts on all elements.

    Attributes
    ----------
    domain : DomainTuple
        The domain on which the Operator's input Field lives.

    Raises
    ------
    TypeError
        If `diagonal` is not a Field.
    ValueError
        If the number of `spaces` differs from the number of subdomains of
        `diagonal`, or if the subdomains of `diagonal` do not match the
        corresponding subdomains of `domain`.

    Notes
    -----
    The field passed to ``__init__`` and the one returned by ``diagonal()``
    are considered non-bare: no additional volume factors are applied
    during operator application.

    See Also
    --------
    EndomorphicOperator
    """

    def __init__(self, diagonal, domain=None, spaces=None):
        super(DiagonalOperator, self).__init__()

        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")

        self._domain = (diagonal.domain if domain is None
                        else DomainTuple.make(domain))

        if spaces is None:
            # The diagonal must cover the whole domain.
            self._spaces = None
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
        else:
            active = utilities.parse_spaces(spaces, len(self._domain))
            self._spaces = active
            if len(active) != len(diagonal.domain):
                raise ValueError("spaces and domain must have the same length")
            # The k-th subdomain of the diagonal has to equal the subdomain
            # of the full domain selected by the k-th entry of `spaces`.
            if any(diagonal.domain[k] != self._domain[target]
                   for k, target in enumerate(active)):
                raise ValueError("domain mismatch")
            if active == tuple(range(len(self._domain))):
                # Acting on every subdomain is the same as spaces=None.
                self._spaces = None

        self._diagonal = diagonal.copy()

        if self._spaces is None:
            self._ldiag = dobj.local_data(self._diagonal.val)
            return

        # Flat list of array axes (of the full domain) that the diagonal
        # spans; all other axes are broadcast over.
        active_axes = [axis
                       for space_index in self._spaces
                       for axis in self._domain.axes[space_index]]

        # NOTE(review): the local shape below is computed for distribution
        # along axis 0; presumably the diagonal's local chunk only lines up
        # with the input's when the first active space starts at index 0,
        # otherwise the full global array is fetched — confirm against dobj.
        fetch = (dobj.local_data if self._spaces[0] == 0
                 else dobj.to_global_data)
        raw_diag = fetch(self._diagonal.val)
        local_shape = dobj.local_shape(self._domain.shape, 0)
        self._reshaper = [extent if axis in active_axes else 1
                          for axis, extent in enumerate(local_shape)]
        self._ldiag = raw_diag.reshape(self._reshaper)

    def apply(self, x, mode):
        self._check_input(x, mode)

        # Adjoint modes use the complex-conjugated diagonal; inverse modes
        # divide instead of multiply.
        conjugate = mode not in (self.TIMES, self.INVERSE_TIMES)
        divide = mode not in (self.TIMES, self.ADJOINT_TIMES)

        factor = self._ldiag
        if conjugate and not np.issubdtype(factor.dtype, np.floating):
            # Real floating-point diagonals are their own conjugate, so the
            # conj() copy is only made for other dtypes.
            factor = factor.conj()

        if divide:
            result = x.val/factor
        else:
            result = x.val*factor
        return Field(x.domain, val=result)

    def diagonal(self):
        """ Returns a copy of the diagonal of the Operator."""
        return self._diagonal.copy()

    @property
    def domain(self):
        return self._domain

    @property
    def capability(self):
        return self._all_ops

    @property
    def inverse(self):
        reciprocal = 1./self._diagonal
        return DiagonalOperator(reciprocal, self._domain, self._spaces)

    @property
    def adjoint(self):
        conjugated = self._diagonal.conjugate()
        return DiagonalOperator(conjugated, self._domain, self._spaces)