# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import object
from ..sugar import create_composed_fft_operator


class TraceProberMixin(object):
    """Accumulate probe/response inner products to estimate a trace.

    Cooperative mixin: every hook chains to ``super()``, so it must precede a
    base class in the MRO that provides ``reset``, ``finish_probe``,
    ``probe_count``, ``compute_variance`` and ``_domain``.

    For each finished probe, the inner product of the (weight(-1)-applied)
    probe field with the probe's response is added to a running sum.
    :attr:`trace` is that sum divided by ``probe_count``. When
    ``compute_variance`` is true, the sum of squared magnitudes is tracked as
    well, which enables :attr:`trace_variance`.
    """

    def __init__(self, *args, **kwargs):
        self.reset()
        # Private switch: when True, probe and response are transformed with
        # a composed FFT (all_to='position') before the inner product.
        # Nothing in this class turns it on.
        self.__evaluate_probe_in_signal_space = False
        super(TraceProberMixin, self).__init__(*args, **kwargs)

    def reset(self):
        """Clear the accumulated sums and cached results, then chain up."""
        self.__sum_of_probings = 0
        self.__sum_of_squares = 0
        self.__trace = None
        self.__trace_variance = None
        super(TraceProberMixin, self).reset()

    def finish_probe(self, probe, pre_result):
        """Fold one probe's inner product into the running sums.

        Parameters
        ----------
        probe : sequence
            ``probe[1]`` is the probe field; ``probe[0]`` is not used here.
        pre_result :
            The response belonging to this probe; combined with ``probe[1]``
            via ``vdot``.
        """
        if self.__evaluate_probe_in_signal_space:
            fft = create_composed_fft_operator(self._domain, all_to='position')
            result = fft(probe[1]).weight(-1).vdot(fft(pre_result))
        else:
            result = probe[1].weight(-1).vdot(pre_result)

        self.__sum_of_probings += result
        if self.compute_variance:
            # conj(x) * x == |x|^2, so complex results accumulate magnitudes.
            self.__sum_of_squares += result.conjugate() * result
        # A new sample makes any previously cached estimate stale.
        self.__trace = None
        self.__trace_variance = None
        super(TraceProberMixin, self).finish_probe(probe, pre_result)

    @property
    def trace(self):
        """Accumulated inner products divided by ``probe_count`` (cached)."""
        if self.__trace is None:
            self.__trace = self.__sum_of_probings/self.probe_count
        return self.__trace

    @property
    def trace_variance(self):
        """Sample variance (1/(n-1) normalisation) of the inner products.

        Cached until the next ``finish_probe`` or ``reset``. With
        ``probe_count == 1`` the 1/(n-1) factor is undefined.

        Raises
        ------
        AttributeError
            If ``compute_variance`` is false (no squares were accumulated).
        """
        if not self.compute_variance:
            raise AttributeError("self.compute_variance is set to False")
        if self.__trace_variance is None:
            # variance = 1/(n-1) (sum(|x|^2) - 1/n*|sum(x)|^2)
            # conj(sum) * mean == |sum|^2 / n; the plain product sum * mean
            # would give sum^2 / n, which is wrong for complex samples.
            n = self.probe_count
            sum_pr = self.__sum_of_probings
            mean = self.trace
            sum_sq = self.__sum_of_squares

            self.__trace_variance = \
                (sum_sq - sum_pr.conjugate()*mean) / (n-1)
        return self.__trace_variance