# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division

import numpy as np

from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from .. import utilities
from .. import dobj


class DiagonalOperator(EndomorphicOperator):
    """ NIFTy class for diagonal operators.

    The NIFTy DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator.
    domain : Domain, tuple of Domain or DomainTuple, optional
        The domain on which the Operator's input Field lives.
        If None, use the domain of "diagonal".
    spaces : int or tuple of int, optional
        The elements of "domain" on which the operator acts.
        If None, it acts on all elements. When given, the indices must be
        strictly increasing, and ``diagonal.domain[i]`` must equal
        ``domain[spaces[i]]`` for every ``i``.
    _ldiag : numpy.ndarray, optional
        Internal use only. A diagonal array that is already shaped to
        broadcast against input data (as built by `inverse` and `adjoint`
        from an existing operator). When given, `diagonal` is ignored and
        `domain` and `spaces` are stored as-is without any validation.

    Raises
    ------
    TypeError
        If `diagonal` is not a Field (and `_ldiag` is not given).
    ValueError
        If `domain`, `spaces` and ``diagonal.domain`` are inconsistent.

    Notes
    -----
    Formally, this operator always supports all operation modes (times,
    adjoint_times, inverse_times and inverse_adjoint_times), even if there
    are diagonal elements with value 0 or infinity. It is the user's
    responsibility to apply the operator only in appropriate ways (e.g. call
    inverse_times only if there are no zeros on the diagonal).

    This shortcoming will hopefully be fixed in the future.
    """

    def __init__(self, diagonal, domain=None, spaces=None, _ldiag=None):
        super(DiagonalOperator, self).__init__()

        if _ldiag is not None:
            # Internal shortcut used by `inverse` and `adjoint`: the array
            # comes from an operator that was already validated, so all
            # checks and reshaping are skipped.
            self._ldiag = _ldiag
            self._domain = domain
            self._spaces = spaces
            return

        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")

        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)

        self._spaces = self._validated_spaces(diagonal, spaces)
        self._diagonal = diagonal.lock()
        self._ldiag = self._make_local_diagonal()

    def _validated_spaces(self, diagonal, spaces):
        """Check `spaces` against `self._domain`; return it or None for "all".

        Returns None when the operator acts on every space of
        `self._domain`, otherwise the tuple produced by
        `utilities.parse_spaces`.
        """
        if spaces is None:
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
            return None

        spaces = utilities.parse_spaces(spaces, len(self._domain))
        if len(spaces) != len(diagonal.domain):
            raise ValueError("spaces and domain must have the same length")
        # `_make_local_diagonal` reshapes the diagonal's data onto the
        # domain's axes in ascending order. A permuted `spaces` would
        # therefore scramble the diagonal silently, and a repeated one would
        # fail with an opaque reshape error, so reject both explicitly.
        if any(a >= b for a, b in zip(spaces, spaces[1:])):
            raise ValueError("spaces must be strictly increasing")
        # if nspc==len(self.diagonal.domain),
        # we could do some optimization
        for i, j in enumerate(spaces):
            if diagonal.domain[i] != self._domain[j]:
                raise ValueError("domain mismatch")
        if spaces == tuple(range(len(self._domain))):
            return None  # acts on every space: use the plain shortcut
        return spaces

    def _make_local_diagonal(self):
        """Return the diagonal's data shaped to broadcast against input data.

        Reads `self._diagonal`, `self._domain` and `self._spaces`. When the
        operator acts only on some spaces, it also stores the target shape
        in `self._reshaper`.
        """
        if self._spaces is None:
            return self._diagonal.local_data

        active_axes = set()
        for space_index in self._spaces:
            active_axes.update(self._domain.axes[space_index])

        # NOTE(review): presumably only the first axis of the full domain is
        # distributed (cf. `dobj.local_shape(..., 0)` below). If the diagonal
        # starts at space 0, its local slab lines up with the input's local
        # slab. Otherwise the full (global) diagonal is broadcast along the
        # inactive first axis. Confirm against the `dobj` backend.
        if self._spaces[0] == 0:
            ldiag = self._diagonal.local_data
        else:
            ldiag = self._diagonal.to_global_data()
        locshape = dobj.local_shape(self._domain.shape, 0)
        # Inactive axes get length 1 so numpy broadcasting extends the
        # diagonal over them.
        self._reshaper = [shp if i in active_axes else 1
                          for i, shp in enumerate(locshape)]
        return ldiag.reshape(self._reshaper)

    def apply(self, x, mode):
        """Multiply (or divide) `x` pixel-wise by the (conjugated) diagonal."""
        self._check_input(x, mode)

        if mode == self.TIMES:
            return Field(x.domain, val=x.val*self._ldiag)
        if mode == self.INVERSE_TIMES:
            return Field(x.domain, val=x.val/self._ldiag)

        # Remaining modes are the adjoint ones. Conjugation only changes
        # complex values, so real/integer diagonals are used unchanged.
        diag = self._ldiag
        if np.issubdtype(diag.dtype, np.complexfloating):
            diag = diag.conj()
        if mode == self.ADJOINT_TIMES:
            return Field(x.domain, val=x.val*diag)
        return Field(x.domain, val=x.val/diag)

    @property
    def domain(self):
        return self._domain

    @property
    def capability(self):
        return self._all_ops

    @property
    def inverse(self):
        """A DiagonalOperator whose diagonal is the reciprocal of this one."""
        return DiagonalOperator(None, self._domain, self._spaces,
                                1./self._ldiag)

    @property
    def adjoint(self):
        """A DiagonalOperator whose diagonal is the conjugate of this one."""
        return DiagonalOperator(None, self._domain,
                                self._spaces, self._ldiag.conjugate())

    def process_sample(self, sample):
        """Return a copy of `sample` scaled pixel-wise by sqrt(diagonal).

        Raises
        ------
        ValueError
            If the diagonal is complex-valued.
        """
        if np.issubdtype(self._ldiag.dtype, np.complexfloating):
            raise ValueError("cannot draw sample from complex-valued operator")

        res = Field.empty_like(sample)
        res.local_data[()] = sample.local_data * np.sqrt(self._ldiag)
        return res

    def draw_sample(self, dtype=np.float64):
        """Draw a normal random field on the domain, scaled by sqrt(diagonal).

        If `Field.from_random` draws unit-variance values (not verified
        here), the result has this operator as its covariance.

        Raises
        ------
        ValueError
            If the diagonal is complex-valued.
        """
        if np.issubdtype(self._ldiag.dtype, np.complexfloating):
            raise ValueError("cannot draw sample from complex-valued operator")

        res = Field.from_random(random_type="normal", domain=self._domain,
                                dtype=dtype)
        res.local_data[()] *= np.sqrt(self._ldiag)
        return res