diagonal_operator.py 7.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# -*- coding: utf-8 -*-

import numpy as np

from d2o import distributed_data_object,\
                STRATEGIES as DISTRIBUTION_STRATEGIES

from nifty.config import about,\
                         nifty_configuration as gc
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator


class DiagonalOperator(EndomorphicOperator):
    """Endomorphic operator that is defined by its diagonal.

    The diagonal is stored as a `Field` living on the operator's domain and
    field_type.  Applying the operator (or its adjoint / inverse) is an
    element-wise multiplication (or division) of the input with that field.

    Parameters
    ----------
    domain, field_type :
        Forwarded unchanged to `EndomorphicOperator.__init__`.
    implemented : bool
        Stored as the `implemented` property.  Together with `bare` it
        decides whether `set_diagonal` weights the given values.
    diagonal :
        Values of the diagonal; anything the `Field` constructor accepts as
        `val`.
    bare : bool
        Whether `diagonal` is given in its `bare` form (see `set_diagonal`).
    copy : bool
        Whether the given `diagonal` is copied before it is stored.
    distribution_strategy :
        Distribution strategy for the stored diagonal.  If None, it is
        inferred by `_parse_distribution_strategy`.
    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), field_type=(), implemented=True,
                 diagonal=None, bare=False, copy=True,
                 distribution_strategy=None):
        super(DiagonalOperator, self).__init__(domain=domain,
                                               field_type=field_type)

        self._implemented = bool(implemented)

        # `_parse_distribution_strategy` itself falls back to the strategy
        # of `diagonal` when none is given, so no pre-processing is needed.
        self._distribution_strategy = self._parse_distribution_strategy(
                               distribution_strategy=distribution_strategy,
                               val=diagonal)

        self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)

    def _times(self, x, spaces, types):
        """Return the diagonal multiplied with `x`."""
        return self._times_helper(x, spaces, types,
                                  operation=lambda z: z.__mul__)

    def _adjoint_times(self, x, spaces, types):
        """Return the adjoint of the diagonal multiplied with `x`."""
        # NOTE(review): when `_times_helper` takes its sub-domain path, `z`
        # is a numpy array, which has no `adjoint` method -- confirm that
        # this path is exercised/handled as intended.
        return self._times_helper(x, spaces, types,
                                  operation=lambda z: z.adjoint().__mul__)

    def _inverse_times(self, x, spaces, types):
        """Return `x` divided by the diagonal."""
        return self._times_helper(x, spaces, types,
                                  operation=lambda z: z.__rdiv__)

    def _adjoint_inverse_times(self, x, spaces, types):
        """Return `x` divided by the adjoint of the diagonal."""
        # NOTE(review): same caveat as in `_adjoint_times` regarding
        # `adjoint` on numpy arrays.
        return self._times_helper(x, spaces, types,
                                  operation=lambda z: z.adjoint().__rdiv__)

    def diagonal(self, bare=False, copy=True):
        """Return the diagonal of the operator.

        Parameters
        ----------
        bare : bool
            If True, the stored diagonal weighted with power -1 is returned
            (`copy` is not consulted in that case).
        copy : bool
            If True (and `bare` is False) a copy of the stored field is
            returned, otherwise the stored field itself.
        """
        if bare:
            diagonal = self._diagonal.weight(power=-1)
        elif copy:
            diagonal = self._diagonal.copy()
        else:
            diagonal = self._diagonal
        return diagonal

    def inverse_diagonal(self, bare=False):
        """Return the element-wise reciprocal of `diagonal(bare=bare)`."""
        return 1/self.diagonal(bare=bare, copy=False)

    def trace(self, bare=False):
        """Return the sum over `diagonal(bare=bare)`."""
        return self.diagonal(bare=bare, copy=False).sum()

    def inverse_trace(self, bare=False):
        """Return the sum over `inverse_diagonal(bare=bare)`."""
        # `inverse_diagonal` has no `copy` parameter; passing one (as was
        # done before) raised a TypeError on every call.
        return self.inverse_diagonal(bare=bare).sum()

    def trace_log(self):
        """Return the sum over the element-wise `np.log` of the diagonal."""
        log_diagonal = self.diagonal(copy=False).apply_scalar_function(np.log)
        return log_diagonal.sum()

    def determinant(self):
        """Return the product over all entries of the diagonal."""
        return self.diagonal(copy=False).val.prod()

    def inverse_determinant(self):
        """Return the reciprocal of `determinant()`."""
        return 1/self.determinant()

    def log_determinant(self):
        """Return `np.log` of `determinant()`."""
        return np.log(self.determinant())

    # ---Mandatory properties and methods---

    @property
    def implemented(self):
        return self._implemented

    @property
    def symmetric(self):
        # set by `set_diagonal`: True if all imaginary parts are zero
        return self._symmetric

    @property
    def unitary(self):
        # set by `set_diagonal`: True if every entry has absolute square 1
        return self._unitary

    # ---Added properties and methods---

    @property
    def distribution_strategy(self):
        return self._distribution_strategy

    def _parse_distribution_strategy(self, distribution_strategy, val):
        """Return the distribution strategy to use for the diagonal.

        If `distribution_strategy` is None, the strategy of `val` is used
        when `val` is a `distributed_data_object` or a `Field`; otherwise a
        warning is printed and the configured default is used.

        Raises
        ------
        ValueError
            If an explicitly given strategy is not listed in
            `DISTRIBUTION_STRATEGIES['all']`.
        """
        if distribution_strategy is None:
            if isinstance(val, (distributed_data_object, Field)):
                distribution_strategy = val.distribution_strategy
            else:
                about.warnings.cprint("WARNING: Datamodel set to default!")
                distribution_strategy = gc['default_distribution_strategy']
        elif distribution_strategy not in DISTRIBUTION_STRATEGIES['all']:
            raise ValueError(about._errors.cstring(
                    "ERROR: Invalid distribution_strategy!"))
        return distribution_strategy

    def set_diagonal(self, diagonal, bare=False, copy=True):
        """Set the diagonal and update the `symmetric`/`unitary` flags.

        Parameters
        ----------
        diagonal :
            New diagonal values; cast by the `Field` constructor.
        bare : bool
            If `bare` and `implemented` are both True the values are
            weighted; if both are False they are weighted with power -1;
            otherwise they are stored as given.
        copy : bool
            Passed to the `Field` constructor.  If False, weighting is not
            done in place so that the externally supplied data is not
            modified.
        """
        # use the casting functionality from Field to process `diagonal`
        f = Field(domain=self.domain,
                  val=diagonal,
                  field_type=self.field_type,
                  distribution_strategy=self.distribution_strategy,
                  copy=copy)

        # weight if the given values were `bare` and `implemented` is True
        # do inverse weightening if the other way around
        if bare and self.implemented:
            weight_kwargs = {}
        elif not bare and not self.implemented:
            weight_kwargs = {'power': -1}
        else:
            weight_kwargs = None

        if weight_kwargs is not None:
            if copy:
                # `f` owns its data, so it is safe to weight it in place
                f.weight(inplace=True, **weight_kwargs)
            else:
                # Inplace weightening would change the external field, so a
                # new field is created; it must replace `f`, otherwise the
                # weighting would be silently dropped.
                f = f.weight(inplace=False, **weight_kwargs)

        # check if the operator is symmetric:
        self._symmetric = (f.val.imag == 0).all()

        # check if the operator is unitary:
        self._unitary = (f.val * f.val.conjugate() == 1).all()

        # store the diagonal-field
        self._diagonal = f

    def _times_helper(self, x, spaces, types, operation):
        """Apply `operation` of the diagonal to `x`.

        Parameters
        ----------
        x :
            The field the operator acts on.
        spaces, types :
            Indices into `x.domain_axes` / `x.field_type_axes` selecting the
            axes the operator acts on; None selects all of them.
        operation : callable
            Maps the diagonal to a one-argument callable (e.g. its bound
            `__mul__`) which is then called with `x` or its local data.

        Returns
        -------
        The result of the call if domain and field_type of `x` match those
        of the operator, otherwise a new field created via `x.copy_empty`
        that holds the locally computed result.
        """
        # if the domain and field_type match directly
        # -> multiply the fields directly
        if x.domain == self.domain and x.field_type == self.field_type:
            # here the actual multiplication takes place
            return operation(self.diagonal(copy=False))(x)

        # if the distribution_strategy of self is sub-slice compatible to
        # the one of x, reshape the local data of self and apply it directly
        active_axes = []
        if spaces is None:
            for axes in x.domain_axes:
                active_axes += axes
        else:
            for space_index in spaces:
                active_axes += x.domain_axes[space_index]

        if types is None:
            for axes in x.field_type_axes:
                active_axes += axes
        else:
            for type_index in types:
                active_axes += x.field_type_axes[type_index]

        axes_local_distribution_strategy = \
            x.val.get_axes_local_distribution_strategy(active_axes)
        if axes_local_distribution_strategy == self.distribution_strategy:
            local_diagonal = self._diagonal.val.get_local_data(copy=False)
        else:
            # create an array that is sub-slice compatible
            redistr_diagonal_val = self._diagonal.val.copy(
                distribution_strategy=axes_local_distribution_strategy)
            local_diagonal = redistr_diagonal_val.get_local_data(copy=False)

        local_x = x.val.get_local_data(copy=False)

        # Broadcast the diagonal along the inactive axes.  The extents are
        # taken from the *local* data of `x`, since `local_diagonal` only
        # holds this process' part of the diagonal as well.
        reshaper = [extent if axis in active_axes else 1
                    for axis, extent in enumerate(local_x.shape)]
        reshaped_local_diagonal = np.reshape(local_diagonal, reshaper)

        # here the actual multiplication takes place
        local_result = operation(reshaped_local_diagonal)(local_x)

        result_field = x.copy_empty(dtype=local_result.dtype)
        result_field.val.set_local_data(local_result, copy=False)
        return result_field