# diagonal_operator.py
# -*- coding: utf-8 -*-

import numpy as np

from d2o import distributed_data_object,\
                STRATEGIES as DISTRIBUTION_STRATEGIES

from nifty.config import about,\
                         nifty_configuration as gc
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator


class DiagonalOperator(EndomorphicOperator):
    """Endomorphic operator defined by a diagonal stored as a `Field`.

    The diagonal lives on the operator's domain. Whether the stored values
    carry volume weights is governed by `implemented` (see `set_diagonal`).
    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), field_type=(), implemented=False,
                 diagonal=None, bare=False, copy=True, datamodel=None):
        """Create the operator and store its diagonal.

        Parameters
        ----------
        domain, field_type, implemented
            Forwarded to `EndomorphicOperator.__init__`.
        diagonal
            Diagonal values; anything accepted as `Field(val=...)`.
        bare : bool
            True if `diagonal` is given without volume weights.
        copy : bool
            Whether the diagonal data is copied when building the Field.
        datamodel
            Distribution strategy. When None it is inferred from `diagonal`
            (d2o or Field), otherwise the configured default is used.

        Raises
        ------
        ValueError
            If an explicit `datamodel` is not a known distribution strategy.
        """
        super(DiagonalOperator, self).__init__(domain=domain,
                                               field_type=field_type,
                                               implemented=implemented)

        self._implemented = bool(implemented)

        # `datamodel` is a read-only property, so the backing attribute is
        # assigned directly. `_parse_datamodel` already handles inference
        # from a d2o / Field diagonal when `datamodel` is None.
        self._datamodel = self._parse_datamodel(datamodel=datamodel,
                                                val=diagonal)

        self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)

    # TODO: the application methods below are not implemented yet; each one
    # currently returns None.
    def _times(self, x, spaces, types):
        pass

    def _adjoint_times(self, x, spaces, types):
        pass

    def _inverse_times(self, x, spaces, types):
        pass

    def _adjoint_inverse_times(self, x, spaces, types):
        pass

    def _inverse_adjoint_times(self, x, spaces, types):
        pass

    def diagonal(self, bare=False, copy=True):
        """Return the stored diagonal field.

        With `bare=True` a field weighted with power -1 is returned, and
        `copy` is ignored. Otherwise a copy is returned if `copy` is True,
        else the internal field itself.
        """
        if bare:
            diagonal = self._diagonal.weight(power=-1)
        elif copy:
            diagonal = self._diagonal.copy()
        else:
            diagonal = self._diagonal
        return diagonal

    def inverse_diagonal(self, bare=False):
        """Return the element-wise reciprocal of the diagonal."""
        # `1.` avoids integer truncation under Python 2 division semantics.
        return 1./self.diagonal(bare=bare, copy=False)

    def trace(self, bare=False):
        """Return the sum of the diagonal entries."""
        return self.diagonal(bare=bare, copy=False).sum()

    def inverse_trace(self, bare=False):
        """Return the sum of the reciprocal diagonal entries."""
        # `inverse_diagonal` has no `copy` parameter and already returns a
        # new object, so only `bare` is forwarded.
        return self.inverse_diagonal(bare=bare).sum()

    def trace_log(self):
        """Return the sum of the logarithms of the diagonal entries."""
        log_diagonal = self.diagonal(copy=False).apply_scalar_function(np.log)
        return log_diagonal.sum()

    def determinant(self):
        """Return the product of the diagonal entries."""
        return self.diagonal(copy=False).val.prod()

    def inverse_determinant(self):
        """Return the reciprocal of `determinant()`."""
        # `1.` avoids integer truncation under Python 2 division semantics.
        return 1./self.determinant()

    def log_determinant(self):
        """Return the natural logarithm of `determinant()`."""
        return np.log(self.determinant())

    # ---Mandatory properties and methods---

    @property
    def implemented(self):
        return self._implemented

    @property
    def symmetric(self):
        # Set by `set_diagonal`: True if all diagonal values are real.
        return self._symmetric

    @property
    def unitary(self):
        # Set by `set_diagonal`: True if every |d|^2 equals exactly 1.
        return self._unitary

    # ---Added properties and methods---

    @property
    def datamodel(self):
        return self._datamodel

    def _parse_datamodel(self, datamodel, val):
        """Resolve the datamodel to use for the diagonal.

        If `datamodel` is None it is taken from `val` (a d2o's distribution
        strategy or a Field's datamodel). Failing that, a warning is printed
        and the configured default is used. An explicit `datamodel` must be
        one of DISTRIBUTION_STRATEGIES['all'], otherwise ValueError is
        raised.
        """
        if datamodel is None:
            if isinstance(val, distributed_data_object):
                datamodel = val.distribution_strategy
            elif isinstance(val, Field):
                datamodel = val.datamodel
            else:
                about.warnings.cprint("WARNING: Datamodel set to default!")
                datamodel = gc['default_datamodel']
        elif datamodel not in DISTRIBUTION_STRATEGIES['all']:
            raise ValueError(about._errors.cstring(
                    "ERROR: Invalid datamodel!"))
        return datamodel

    def set_diagonal(self, diagonal, bare=False, copy=True):
        """Store `diagonal` as the operator's diagonal field.

        The values are weighted (power 1) if they were given `bare` and the
        operator is `implemented`. They are weighted with power -1 if they
        were given non-bare and the operator is not `implemented`. The
        `symmetric` and `unitary` flags are recomputed from the result.
        """
        # use the casting functionality from Field to process `diagonal`
        f = Field(domain=self.domain,
                  val=diagonal,
                  field_type=self.field_type,
                  datamodel=self.datamodel,
                  copy=copy)

        # With copy=True, `f` owns its data and may be weighted in place.
        # With copy=False, `f` may share data with the caller, so a weighted
        # copy is made instead. In both cases the returned field must be
        # kept, because a non-inplace `weight` does not modify `f`.
        # NOTE(review): assumes Field.weight returns the weighted field
        # (itself when inplace=True) -- confirm against Field.
        if bare and self.implemented:
            f = f.weight(inplace=copy)
        elif not bare and not self.implemented:
            f = f.weight(inplace=copy, power=-1)

        # check if the operator is symmetric:
        self._symmetric = (f.val.imag == 0).all()

        # check if the operator is unitary:
        self._unitary = (f.val * f.val.conjugate() == 1).all()

        # store the diagonal-field
        self._diagonal = f