trace_prober_mixin.py 2.01 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19


20
21
class TraceProberMixin(object):
    """Prober mixin that accumulates a stochastic trace estimate.

    For every finished probe the scalar ``probe[1].dot(pre_result, bare=True)``
    is added to a running sum; ``trace`` is that sum divided by the number
    of probes. If ``compute_variance`` is enabled, the sum of squared
    moduli is tracked too, so that ``trace_variance`` can report the sample
    variance of the per-probe results.

    This is a cooperative mixin: it relies on a class later in the MRO to
    provide ``reset()``, ``finish_probe(probe, pre_result)``, and the
    attributes ``probe_count`` and ``compute_variance``.
    NOTE(review): presumably ``probe`` is a ``(key, field)`` pair and
    ``pre_result`` is the operator applied to ``probe[1]`` -- confirm
    against the base prober.
    """

    def __init__(self, *args, **kwargs):
        # Initialise the accumulators before the cooperative base
        # constructor runs.
        self.reset()
        super(TraceProberMixin, self).__init__(*args, **kwargs)

    def reset(self):
        """Clear all accumulated sums and cached results, then defer to super."""
        self.__sum_of_probings = 0
        self.__sum_of_squares = 0
        self.__trace = None
        self.__trace_variance = None
        super(TraceProberMixin, self).reset()

    def finish_probe(self, probe, pre_result):
        """Accumulate the contribution of one probe.

        Parameters
        ----------
        probe : indexable
            Its element ``[1]`` must provide ``dot(other, bare=True)``.
        pre_result :
            Passed as the argument of that ``dot`` call.
        """
        result = probe[1].dot(pre_result, bare=True)
        self.__sum_of_probings += result
        if self.compute_variance:
            # conj(x) * x == |x|**2, so complex results are handled correctly.
            self.__sum_of_squares += result.conjugate() * result
        # New data arrived: any previously cached estimate is now stale.
        self.__trace = None
        self.__trace_variance = None
        super(TraceProberMixin, self).finish_probe(probe, pre_result)

    @property
    def trace(self):
        """Mean of the accumulated probe results (cached until the next probe)."""
        if self.__trace is None:
            self.__trace = self.__sum_of_probings / self.probe_count
        return self.__trace

    @property
    def trace_variance(self):
        """Sample variance (Bessel-corrected) of the accumulated probe results.

        Raises
        ------
        AttributeError
            If ``compute_variance`` is False.

        With plain Python numbers and ``probe_count <= 1`` the division by
        ``n - 1`` raises ``ZeroDivisionError``.
        """
        if not self.compute_variance:
            raise AttributeError("self.compute_variance is set to False")
        if self.__trace_variance is None:
            # variance = 1/(n-1) * (sum(|x|^2) - 1/n * |sum(x)|^2)
            #          = 1/(n-1) * (sum(|x|^2) - conj(sum(x)) * mean)
            # The conjugate keeps this consistent with the |x|^2 terms in
            # sum_sq for complex data; it is a no-op for real data.
            n = self.probe_count
            sum_pr = self.__sum_of_probings
            mean = self.trace
            sum_sq = self.__sum_of_squares

            self.__trace_variance = (
                (sum_sq - sum_pr.conjugate() * mean) / (n - 1))
        return self.__trace_variance