# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.


class DiagonalProberMixin(object):
    """Mixin that accumulates a stochastic estimate of a diagonal.

    Each finished probe contributes ``probe[1].conjugate() * pre_result``.
    The estimate is the mean of those contributions over ``probe_count``
    probes. Presumably ``probe[1]`` is the random probing vector and
    ``pre_result`` is the operator applied to it; confirm against the
    Prober implementation.

    The class this is mixed into must provide, via cooperative
    ``super()`` calls, the methods ``reset()`` and
    ``finish_probe(probe, pre_result)``. It must also provide the
    attributes ``compute_variance`` and ``probe_count``.
    """

    def __init__(self, *args, **kwargs):
        # Initialise the accumulators before the rest of the MRO runs.
        self.reset()
        super(DiagonalProberMixin, self).__init__(*args, **kwargs)

    def reset(self):
        """Clear the accumulated sums and cached results, then chain on."""
        self.__sum_of_probings = 0
        self.__sum_of_squares = 0
        self.__diagonal = None
        self.__diagonal_variance = None
        super(DiagonalProberMixin, self).reset()

    def finish_probe(self, probe, pre_result):
        """Accumulate one probe's contribution to the diagonal estimate.

        Parameters
        ----------
        probe : indexable
            Only ``probe[1]`` is used here. It is conjugated and
            multiplied with ``pre_result``.
        pre_result : object
            Result associated with this probe. It is forwarded unchanged
            to the next ``finish_probe`` in the MRO.
        """
        result = probe[1].conjugate()*pre_result
        self.__sum_of_probings += result
        if self.compute_variance:
            # Squared modulus |x|^2, so complex contributions are handled.
            self.__sum_of_squares += result.conjugate() * result
        # New data makes any previously cached estimate stale.
        self.__diagonal = None
        self.__diagonal_variance = None
        super(DiagonalProberMixin, self).finish_probe(probe, pre_result)

    @property
    def diagonal(self):
        """Mean of the accumulated probe contributions (cached until the
        next probe or ``reset``)."""
        if self.__diagonal is None:
            self.__diagonal = self.__sum_of_probings/self.probe_count
        return self.__diagonal

    @property
    def diagonal_variance(self):
        """Sample variance (1/(n-1) normalisation) of the probe
        contributions, cached until the next probe or ``reset``.

        Raises
        ------
        AttributeError
            If ``compute_variance`` is False, because the sum of squares
            was not accumulated.
        """
        if not self.compute_variance:
            raise AttributeError("self.compute_variance is set to False")
        if self.__diagonal_variance is None:
            # variance = 1/(n-1) * (sum(|x|^2) - 1/n * |sum(x)|^2)
            #          = 1/(n-1) * (sum(|x|^2) - conj(sum(x)) * mean)
            # The conjugate matters for complex data: without it the
            # subtracted term would be sum(x)^2/n, which is wrong.
            n = self.probe_count
            sum_pr = self.__sum_of_probings
            mean = self.diagonal
            sum_sq = self.__sum_of_squares

            self.__diagonal_variance = (
                (sum_sq - sum_pr.conjugate()*mean) / (n-1))
        return self.__diagonal_variance