diagonal_operator.py 7.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19
20
21
22
23

import numpy as np

from d2o import distributed_data_object,\
                STRATEGIES as DISTRIBUTION_STRATEGIES

24
from nifty.config import nifty_configuration as gc
25
26
27
28
29
30
31
32
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator


class DiagonalOperator(EndomorphicOperator):

    # ---Overwritten properties and methods---

33
    def __init__(self, domain=(), implemented=True,
                 diagonal=None, bare=False, copy=True,
                 distribution_strategy=None):
        """Create a diagonal operator on `domain`.

        Parameters
        ----------
        domain : optional
            Parsed via `self._parse_domain`; defaults to the empty tuple.
        implemented : bool
            Stored as `bool(implemented)`; `set_diagonal` combines it with
            `bare` to decide whether the given values get (inverse) weighted.
        diagonal : optional
            Values of the diagonal, forwarded to `set_diagonal`.
        bare : bool
            Forwarded to `set_diagonal`.
        copy : bool
            Forwarded to `set_diagonal`.
        distribution_strategy : optional
            If None and `diagonal` is a `distributed_data_object` or a
            `Field`, the strategy of `diagonal` is used. The final value is
            resolved/validated by `_parse_distribution_strategy`.
        """
        self._domain = self._parse_domain(domain)
        self._implemented = bool(implemented)

        # Inherit the strategy from an already-distributed diagonal.
        if distribution_strategy is None and isinstance(
                diagonal, (distributed_data_object, Field)):
            distribution_strategy = diagonal.distribution_strategy

        self._distribution_strategy = self._parse_distribution_strategy(
            distribution_strategy=distribution_strategy,
            val=diagonal)

        self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)

52
53
    def _times(self, x, spaces):
        """Apply the operator: element-wise product of diagonal and `x`."""
        def multiply_by(diag):
            return diag.__mul__
        return self._times_helper(x, spaces, operation=multiply_by)
54

55
56
    def _adjoint_times(self, x, spaces):
        """Apply the adjoint: multiply `x` by the complex-conjugated diagonal.

        The adjoint of a diagonal operator is the diagonal operator with
        conjugated entries. `_times_helper` may pass this callable either the
        stored diagonal field or a plain numpy array of local data (when the
        domains do not match directly). Numpy arrays have no `adjoint()`
        method, so `conjugate()` is used, which works on arrays too.
        NOTE(review): assumes `Field` provides `conjugate()` returning the
        conjugated field, as numpy does -- confirm.
        """
        return self._times_helper(x, spaces,
                                  operation=lambda z: z.conjugate().__mul__)
58

59
60
    def _inverse_times(self, x, spaces):
        """Apply the inverse: divide `x` element-wise by the diagonal."""
        def divide_by(diag):
            # reflected division: diag.__rdiv__(x) evaluates x / diag
            return diag.__rdiv__
        return self._times_helper(x, spaces, operation=divide_by)
61

62
63
    def _adjoint_inverse_times(self, x, spaces):
        """Apply the adjoint inverse: divide `x` by the conjugated diagonal.

        `conjugate()` is used instead of `adjoint()` because `_times_helper`
        may hand this callable a plain numpy array of local data, and numpy
        arrays have no `adjoint()` method.
        NOTE(review): assumes `Field` provides `conjugate()` like numpy does
        -- confirm.
        """
        return self._times_helper(x, spaces,
                                  operation=lambda z: z.conjugate().__rdiv__)
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98

    def diagonal(self, bare=False, copy=True):
        """Return the stored diagonal field.

        Parameters
        ----------
        bare : bool
            If True, return `self._diagonal.weight(power=-1)`. This takes
            precedence over `copy`.
        copy : bool
            If True (and `bare` is False), return a copy of the stored
            field. If False, return the stored field object itself.
        """
        if bare:
            return self._diagonal.weight(power=-1)
        if not copy:
            return self._diagonal
        return self._diagonal.copy()

    def inverse_diagonal(self, bare=False):
        """Return the element-wise reciprocal `1/diagonal(bare=bare)`."""
        current = self.diagonal(bare=bare, copy=False)
        return 1/current

    def trace(self, bare=False):
        """Return the sum of all diagonal entries."""
        entries = self.diagonal(bare=bare, copy=False)
        return entries.sum()

    def inverse_trace(self, bare=False):
        """Return the sum of the reciprocals of the diagonal entries.

        `inverse_diagonal` only accepts `bare`. The previous call also passed
        `copy=False`, which raised a TypeError on every invocation.
        """
        return self.inverse_diagonal(bare=bare).sum()

    def trace_log(self):
        """Return the sum of the natural logarithms of the diagonal entries."""
        entries = self.diagonal(copy=False)
        return entries.apply_scalar_function(np.log).sum()

    def determinant(self):
        """Return the product of all diagonal entries."""
        entries = self.diagonal(copy=False).val
        return entries.prod()

    def inverse_determinant(self):
        """Return the reciprocal of `determinant()`."""
        det = self.determinant()
        return 1/det

    def log_determinant(self):
        """Return the natural logarithm of `determinant()`."""
        det = self.determinant()
        return np.log(det)

    # ---Mandatory properties and methods---

99
100
101
102
    @property
    def domain(self):
        """The domain parsed by `_parse_domain` in `__init__`."""
        domain_value = self._domain
        return domain_value

103
104
105
106
    @property
    def implemented(self):
        """The boolean `implemented` flag given at construction.

        `set_diagonal` uses it together with `bare` to decide whether the
        given diagonal values are weighted or inverse-weighted.
        """
        flag = self._implemented
        return flag

107
    @property
    def self_adjoint(self):
        """True if every diagonal entry has a zero imaginary part.

        Computed lazily and cached; `set_diagonal` resets the cache.
        """
        cached = self._self_adjoint
        if cached is None:
            cached = (self._diagonal.val.imag == 0).all()
            self._self_adjoint = cached
        return cached
112
113
114

    @property
    def unitary(self):
        """True if `d * conj(d) == 1` holds exactly for every diagonal entry.

        Computed lazily and cached; `set_diagonal` resets the cache.
        """
        if self._unitary is not None:
            return self._unitary
        values = self._diagonal.val
        self._unitary = (values * values.conjugate() == 1).all()
        return self._unitary

    # ---Added properties and methods---

    @property
    def distribution_strategy(self):
        """The distribution strategy resolved in `__init__`."""
        strategy = self._distribution_strategy
        return strategy
125

126
127
    def _parse_distribution_strategy(self, distribution_strategy, val):
        """Resolve the distribution strategy to use for the diagonal.

        An explicitly given strategy must be one of
        `DISTRIBUTION_STRATEGIES['all']`. If none is given, it is taken from
        `val` when that is a `distributed_data_object` or a `Field`.
        Otherwise the configured default is used and an info message is
        logged.

        Raises
        ------
        ValueError
            If an explicitly given strategy is not a known one.
        """
        if distribution_strategy is not None:
            if distribution_strategy not in DISTRIBUTION_STRATEGIES['all']:
                raise ValueError(
                        "Invalid distribution_strategy!")
            return distribution_strategy

        if isinstance(val, (distributed_data_object, Field)):
            return val.distribution_strategy

        self.logger.info("Datamodel set to default!")
        return gc['default_distribution_strategy']
139
140
141
142
143

    def set_diagonal(self, diagonal, bare=False, copy=True):
        """Store `diagonal` as the operator's diagonal field.

        Parameters
        ----------
        diagonal :
            Any value `Field(val=...)` accepts. It is cast onto
            `self.domain` using `self.distribution_strategy`.
        bare : bool
            Whether the given values are bare. They are weighted when `bare`
            and `self.implemented` are both True, and inverse-weighted
            (`power=-1`) when both are False.
        copy : bool
            Forwarded to `Field(copy=...)`. If True, the resulting field owns
            its data and is weighted in place. Otherwise a new weighted field
            is produced, so that external data is not modified.

        The cached `self_adjoint` and `unitary` values are reset.
        """
        # use the casting functionality from Field to process `diagonal`
        f = Field(domain=self.domain,
                  val=diagonal,
                  distribution_strategy=self.distribution_strategy,
                  copy=copy)

        # Weight if the given values were `bare` and `implemented` is True;
        # do inverse weighting if it is the other way around.
        # The out-of-place result must be kept. Previously
        # `f.weight(inplace=copy)` discarded it, so with copy=False no
        # weighting happened at all.
        if bare and self.implemented:
            if copy:
                # `f` holds a copy, so in-place weighting is safe
                f.weight(inplace=True)
            else:
                # in-place weighting could change the caller's data
                f = f.weight(inplace=False)
        elif not bare and not self.implemented:
            if copy:
                f.weight(inplace=True, power=-1)
            else:
                f = f.weight(inplace=False, power=-1)

        # Reset the cached properties; they depend on the diagonal values.
        self._self_adjoint = None
        self._unitary = None

        # store the diagonal-field
        self._diagonal = f
166

167
168
    def _times_helper(self, x, spaces, operation):
        """Apply `operation(diagonal)` to the field `x`.

        Parameters
        ----------
        x : Field
            The input field.
        spaces : sequence of int or None
            Indices into `x.domain` selecting the spaces the operator acts
            on. None means all axes of `x`.
        operation : callable
            Maps a diagonal to a callable that is applied to the data of
            `x`. It receives the stored diagonal field when
            `x.domain == self.domain`, and a numpy array of local data
            otherwise.

        Returns
        -------
        Field
            `operation(self.diagonal(copy=False))(x)` if the domains match
            directly. Otherwise, a field created by
            `x.copy_empty(dtype=...)` that holds the locally computed result.
        """
        # if the domain matches directly
        # -> multiply the fields directly
        if x.domain == self.domain:
            # here the actual multiplication takes place
            return operation(self.diagonal(copy=False))(x)

        # Collect the axes of `x` that the operator acts on.
        if spaces is None:
            active_axes = list(range(len(x.shape)))
        else:
            active_axes = []
            for space_index in spaces:
                active_axes += x.domain_axes[space_index]

        # if the distribution_strategy of self is sub-slice compatible to
        # the one of x, reshape the local data of self and apply it directly
        axes_local_distribution_strategy = \
            x.val.get_axes_local_distribution_strategy(active_axes)
        if axes_local_distribution_strategy == self.distribution_strategy:
            local_diagonal = self._diagonal.val.get_local_data(copy=False)
        else:
            # create an array that is sub-slice compatible
            self.logger.warn("The input field is not sub-slice compatible to "
                             "the distribution strategy of the operator.")
            redistr_diagonal_val = self._diagonal.val.copy(
                distribution_strategy=axes_local_distribution_strategy)
            local_diagonal = redistr_diagonal_val.get_local_data(copy=False)

        # Keep the extent of active axes and use 1 elsewhere, so the local
        # diagonal broadcasts against the local data of `x`. `range` is used
        # instead of the Python-2-only `xrange`.
        reshaper = [x.shape[i] if i in active_axes else 1
                    for i in range(len(x.shape))]
        reshaped_local_diagonal = np.reshape(local_diagonal, reshaper)

        # here the actual multiplication takes place
        local_result = operation(reshaped_local_diagonal)(
                           x.val.get_local_data(copy=False))

        result_field = x.copy_empty(dtype=local_result.dtype)
        result_field.val.set_local_data(local_result, copy=False)
        return result_field