probing.py 4.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

18
from .multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
19
20
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.operator import Operator
Martin Reinecke's avatar
Martin Reinecke committed
21
from .sugar import makeField, from_random
22

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
23

24
class StatCalculator(object):
    """Helper class to compute mean and variance of a set of inputs.

    The statistics are accumulated with Welford's online recurrence: a running
    mean and a running sum of squared deviations (``M2``) are updated every
    time a sample is added.

    Notes
    -----
    - The memory usage of this object is constant, i.e. it does not increase
      with the number of samples added.
    - The code computes the unbiased variance (which contains a `1./(n-1)`
      term for `n` samples).
    """

    def __init__(self):
        self._count = 0
        # Running mean; set by the first call to `add`.
        self._mean = None
        # Running sum of squared deviations from the mean; set by `add`.
        self._M2 = None

    def add(self, value):
        """Adds a sample.

        Parameters
        ----------
        value: any type that supports multiplication by a scalar and
               element-wise addition/subtraction/multiplication.
        """
        self._count += 1
        if self._count == 1:
            # Multiplying by a scalar yields a new object, so the stored mean
            # does not alias the caller's `value`.
            self._mean = 1.*value
            self._M2 = 0.*value
        else:
            delta = value - self._mean
            # Read the private attribute directly: the public `mean` property
            # returns a copy, which would be a wasted allocation per sample.
            self._mean = self._mean + delta*(1./self._count)
            delta2 = value - self._mean
            self._M2 = self._M2 + delta*delta2

    @property
    def mean(self):
        """
        value type : the mean of all samples added so far.

        A copy is returned, so modifying it does not affect the accumulator.

        Raises
        ------
        RuntimeError
            If no sample has been added yet.
        """
        if self._count == 0:
            raise RuntimeError("mean requires at least one sample")
        return 1.*self._mean

    @property
    def var(self):
        """
        value type : the unbiased variance of all samples added so far.

        Raises
        ------
        RuntimeError
            If fewer than two samples have been added.
        """
        if self._count < 2:
            raise RuntimeError(
                "unbiased variance requires at least two samples")
        return self._M2 * (1./(self._count-1))


Martin Reinecke's avatar
Martin Reinecke committed
74
def probe_with_posterior_samples(op, post_op, nprobes, dtype):
    '''Estimates mean and variance of (optionally transformed) samples.

    Draws `nprobes` samples via ``op.draw_sample(from_inverse=True)``,
    optionally maps each one through `post_op`, and accumulates their
    statistics with a :class:`StatCalculator`.

    Parameters
    ----------
    op : EndomorphicOperator
        Operator from which the samples are drawn, using
        ``draw_sample(from_inverse=True)``.
    post_op : Operator or None
        If not None, applied to every sample before it is accumulated. Its
        domain must equal ``op.target``.
    nprobes : int
        Number of samples which shall be drawn.
    dtype :
        Currently ignored; kept only for backward compatibility of the
        signature.

    Returns
    -------
    tuple
        ``(mean, var)``. `var` is the unbiased sample variance, or None if
        ``nprobes == 1`` (the unbiased variance needs at least two samples).

    Raises
    ------
    TypeError
        If `op` is not an EndomorphicOperator, or `post_op` is neither None
        nor an Operator.
    ValueError
        If ``post_op.domain`` differs from ``op.target``.
    RuntimeError
        If ``nprobes < 1`` (propagated from :class:`StatCalculator`).
    '''
    if not isinstance(op, EndomorphicOperator):
        raise TypeError("op must be an EndomorphicOperator")
    if post_op is not None:
        if not isinstance(post_op, Operator):
            raise TypeError("post_op must be an Operator or None")
        # Compare by equality: two equal domain objects need not be the same
        # instance, and an identity check would spuriously reject them.
        if post_op.domain != op.target:
            raise ValueError("post_op.domain must be equal to op.target")

    sc = StatCalculator()
    for _ in range(nprobes):
        sample = op.draw_sample(from_inverse=True)
        sc.add(sample if post_op is None else post_op(sample))

    if nprobes == 1:
        return sc.mean, None
    return sc.mean, sc.var
Martin Reinecke's avatar
Martin Reinecke committed
110
111
112


def probe_diagonal(op, nprobes, random_type="pm1"):
    """Estimates the diagonal of an endomorphic operator by random probing.

    For each of `nprobes` randomly generated input vectors :math:`v_i` the
    operator is applied, giving :math:`r_i`. The returned estimate is the
    sample mean of the element-wise products :math:`r_i^\\dagger v_i`
    (computed as ``conjugate(r_i) * v_i``).

    Parameters
    ----------
    op: EndomorphicOperator
        The operator whose diagonal is estimated.
    nprobes: int
        How many random probe vectors to use.
    random_type: str
        Name of the random distribution used to fill each probe vector,
        forwarded to `from_random`. The default `pm1` fills the probe with
        random values of +1 and -1.

    Returns
    -------
    Field
        The estimated diagonal.
    """
    accumulator = StatCalculator()
    probe_domain = op.domain
    for _ in range(nprobes):
        probe = from_random(probe_domain, random_type)
        response = op(probe)
        accumulator.add(response.conjugate()*probe)
    return accumulator.mean
140
141
142
143
144
145
146
147
148


def approximation2endo(op, nsamples):
    """Builds a per-key sample-variance MultiField with zeros replaced by one.

    Draws `nsamples` samples via ``op.draw_sample()``, computes their
    unbiased variance with :class:`StatCalculator`, and then, for every
    entry of the resulting MultiField, replaces values that are exactly zero
    by 1.

    Parameters
    ----------
    op :
        Object providing ``draw_sample()``. The variance of its samples must
        support ``to_dict()`` (i.e. be a MultiField).
    nsamples : int
        Number of samples to draw. Must be at least 2; otherwise
        :class:`StatCalculator` raises RuntimeError.

    Returns
    -------
    MultiField
        The sample variance, with exact zeros set to 1.
    """
    accumulator = StatCalculator()
    for _ in range(nsamples):
        accumulator.add(op.draw_sample())
    variance = accumulator.var

    # NOTE(review): replacing zeros by one presumably keeps the result
    # usable as an invertible diagonal — intent not visible here, confirm.
    patched = {}
    for key, component in variance.to_dict().items():
        values = component.val_rw()
        values[values == 0] = 1
        patched[key] = makeField(component.domain, values)
    return MultiField.from_dict(patched)