diagonal_operator.py 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-

import numpy as np

from d2o import distributed_data_object,\
                STRATEGIES as DISTRIBUTION_STRATEGIES

from nifty.config import about,\
                         nifty_configuration as gc
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator


class DiagonalOperator(EndomorphicOperator):
    """Endomorphic operator described entirely by its diagonal.

    The diagonal is stored as a `Field` in `self._diagonal`. Values passed in
    as "bare" are weighted once (via `Field.weight`) before being stored, and
    `diagonal(bare=True)` undoes exactly that one weighting.
    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), field_type=(), implemented=False,
                 diagonal=None, bare=False, datamodel=None, copy=True):
        """Create the operator and store its diagonal.

        Parameters
        ----------
        domain, field_type, implemented :
            Forwarded unchanged to `EndomorphicOperator.__init__`.
        diagonal :
            Diagonal values; cast to a `Field` by `set_diagonal`.
        bare : bool
            Whether `diagonal` holds bare (unweighted) values.
        datamodel :
            Distribution strategy. If None it is inferred from `diagonal`
            (d2o or Field), otherwise the configured default is used.
        copy : bool
            Forwarded to the `Field` constructor in `set_diagonal`.

        Raises
        ------
        ValueError
            If an explicit `datamodel` is not a known distribution strategy.
        """
        super(DiagonalOperator, self).__init__(domain=domain,
                                               field_type=field_type,
                                               implemented=implemented)

        # `datamodel` is a read-only property (no setter), so the parsed
        # value must be stored on its private backing attribute.
        # `_parse_datamodel` already infers the datamodel from `diagonal`
        # when none is given, so no separate inference is needed here.
        self._datamodel = self._parse_datamodel(datamodel=datamodel,
                                                val=diagonal)

        self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)

    # TODO: the application methods below are not implemented yet; each one
    # currently returns None.

    def _times(self, x, spaces, types):
        pass

    def _adjoint_times(self, x, spaces, types):
        pass

    def _inverse_times(self, x, spaces, types):
        pass

    def _adjoint_inverse_times(self, x, spaces, types):
        pass

    def _inverse_adjoint_times(self, x, spaces, types):
        pass

    def diagonal(self, bare=False, copy=True):
        """Return the diagonal as a Field.

        Parameters
        ----------
        bare : bool
            If True, return the diagonal with one weighting removed
            (`weight(power=-1)`); `copy` is then irrelevant.
        copy : bool
            If False (and not `bare`), return the internally stored Field
            itself rather than a copy.
        """
        if bare:
            diagonal = self._diagonal.weight(power=-1)
        elif copy:
            diagonal = self._diagonal.copy()
        else:
            diagonal = self._diagonal
        return diagonal

    def inverse_diagonal(self, bare=False):
        """Return the element-wise reciprocal of the diagonal."""
        # `1/...` always produces a new object, so no copy is needed first.
        return 1/self.diagonal(bare=bare, copy=False)

    def trace(self, bare=False):
        """Return the sum of the diagonal entries."""
        return self.diagonal(bare=bare, copy=False).sum()

    def inverse_trace(self, bare=False):
        """Return the sum of the reciprocal diagonal entries."""
        # `inverse_diagonal` takes no `copy` argument; passing one raised
        # a TypeError.
        return self.inverse_diagonal(bare=bare).sum()

    def trace_log(self):
        """Return the sum of the element-wise natural log of the diagonal."""
        log_diagonal = self.diagonal(copy=False).apply_scalar_function(np.log)
        return log_diagonal.sum()

    def determinant(self):
        """Return the product of the diagonal entries."""
        return self.diagonal(copy=False).val.prod()

    def inverse_determinant(self):
        """Return the reciprocal of `determinant()`."""
        return 1/self.determinant()

    def log_determinant(self):
        """Return the natural log of `determinant()`."""
        # NOTE(review): the product may overflow/underflow for large
        # operators; `trace_log()` is the numerically safer equivalent for
        # positive real diagonals.
        return np.log(self.determinant())

    # ---Mandatory properties and methods---

    @property
    def symmetric(self):
        """True if the diagonal had no imaginary part when it was set."""
        return self._symmetric

    @property
    def unitary(self):
        """True if every diagonal entry times its conjugate equals 1."""
        return self._unitary

    # ---Added properties and methods---

    @property
    def datamodel(self):
        """Distribution strategy used for the stored diagonal Field."""
        return self._datamodel

    def _parse_datamodel(self, datamodel, val):
        """Resolve the datamodel to use.

        If `datamodel` is None it is taken from `val` (d2o distribution
        strategy or Field datamodel), falling back to the configured default
        with a warning. An explicit `datamodel` is validated.

        Raises
        ------
        ValueError
            If an explicit `datamodel` is not in
            `DISTRIBUTION_STRATEGIES['all']`.
        """
        if datamodel is None:
            if isinstance(val, distributed_data_object):
                datamodel = val.distribution_strategy
            elif isinstance(val, Field):
                datamodel = val.datamodel
            else:
                about.warnings.cprint("WARNING: Datamodel set to default!")
                datamodel = gc['default_datamodel']
        elif datamodel not in DISTRIBUTION_STRATEGIES['all']:
            raise ValueError(about._errors.cstring(
                    "ERROR: Invalid datamodel!"))
        return datamodel

    def set_diagonal(self, diagonal, bare=False, copy=True):
        """Store `diagonal` as the operator's diagonal.

        Parameters
        ----------
        diagonal :
            Values cast to a Field over `self.domain` / `self.field_type`.
        bare : bool
            If True, the values are bare and get weighted once before being
            stored. If False, they are stored as given.
        copy : bool
            Forwarded to the `Field` constructor. With `copy=False` and
            `bare=True`, the in-place weighting may act on the caller's data.
        """
        # use the casting functionality from Field to process `diagonal`
        f = Field(domain=self.domain,
                  val=diagonal,
                  field_type=self.field_type,
                  datamodel=self.datamodel,
                  copy=copy)

        # weight only if the given values were `bare`; this mirrors
        # `diagonal(bare=True)`, which removes exactly one weighting.
        if bare:
            f.weight(inplace=True)

        # a diagonal without imaginary part is treated as symmetric
        self._symmetric = (f.val.imag == 0).all()

        # unitary iff |d_i|^2 == 1 for every entry.
        # NOTE(review): exact float comparison; round-off can yield False
        # for numerically unitary diagonals.
        self._unitary = (f.val * f.val.conjugate() == 1).all()

        # store the diagonal-field
        self._diagonal = f