# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from .multi_field import MultiField
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.operator import Operator
from .sugar import from_global_data, from_random


class StatCalculator(object):
    """Helper class to compute mean and variance of a set of inputs.

    Samples are accumulated one at a time with Welford's online algorithm
    (running mean plus running sum of squared deviations `_M2`).

    Notes
    -----
    - The memory usage of this object is constant, i.e. it does not increase
      with the number of samples added.
    - The code computes the unbiased variance (which contains a `1./(n-1)`
      term for `n` samples).
    """

    def __init__(self):
        # Number of samples added so far. `_mean` and `_M2` are created by
        # the first call to `add`, because their type is only known once a
        # sample has been seen.
        self._count = 0

    def add(self, value):
        """Adds a sample.

        Parameters
        ----------
        value: any type that supports multiplication by a scalar and
               element-wise addition/subtraction/multiplication.
        """
        self._count += 1
        if self._count == 1:
            # `1.*value` builds a new object (promoting integer data to
            # float) instead of aliasing the caller's `value`.
            self._mean = 1.*value
            self._M2 = 0.*value
        else:
            delta = value - self._mean
            # Use the private accumulator directly; the public `mean`
            # property would allocate an unnecessary copy on every call.
            self._mean = self._mean + delta*(1./self._count)
            delta2 = value - self._mean
            self._M2 = self._M2 + delta*delta2

    @property
    def mean(self):
        """
        value type : the mean of all samples added so far.

        Raises
        ------
        RuntimeError
            If no sample has been added yet.
        """
        if self._count == 0:
            raise RuntimeError("no samples have been added yet")
        # Return a copy so callers cannot modify the internal accumulator.
        return 1.*self._mean

    @property
    def var(self):
        """
        value type : the unbiased variance of all samples added so far.

        Raises
        ------
        RuntimeError
            If fewer than two samples have been added.
        """
        if self._count < 2:
            raise RuntimeError(
                "at least two samples are needed to compute the variance")
        return self._M2 * (1./(self._count-1))


def probe_with_posterior_samples(op, post_op, nprobes):
    '''Estimates mean and variance of samples drawn from the inverse of `op`.

    `nprobes` samples are drawn with ``op.draw_sample(from_inverse=True)``.
    Each one is optionally passed through `post_op`, and all of them are
    accumulated with a :class:`StatCalculator`.

    Parameters
    ----------
    op : EndomorphicOperator
        Operator from whose inverse the samples are drawn.
    post_op : Operator or None
        If not None, applied to every sample before the statistics are
        accumulated. Its domain must be `op.target`.
    nprobes : int
        Number of samples which shall be drawn.

    Returns
    -------
    tuple
        ``(mean, var)`` of the (post-processed) samples. `var` is the
        unbiased sample variance, or None if `nprobes == 1`.

    Raises
    ------
    TypeError
        If `op` is not an EndomorphicOperator, or `post_op` is neither None
        nor an Operator.
    ValueError
        If `post_op.domain` is not `op.target`.
    RuntimeError
        If `nprobes < 1` (no samples, so the mean is undefined).
    '''
    if not isinstance(op, EndomorphicOperator):
        raise TypeError("op must be an EndomorphicOperator")
    if post_op is not None:
        if not isinstance(post_op, Operator):
            raise TypeError("post_op must be None or an Operator")
        # NOTE(review): identity rather than equality comparison; presumably
        # relies on domain objects being unique/cached -- confirm before
        # switching to `!=`.
        if post_op.domain is not op.target:
            raise ValueError("post_op.domain must be op.target")

    sc = StatCalculator()
    for _ in range(nprobes):
        sample = op.draw_sample(from_inverse=True)
        sc.add(sample if post_op is None else post_op(sample))

    # The unbiased variance needs at least two samples.
    if nprobes == 1:
        return sc.mean, None
    return sc.mean, sc.var


def probe_diagonal(op, nprobes, random_type="pm1"):
    '''Estimates the diagonal of an endomorphic operator by random probing.

    In each of `nprobes` rounds a random vector :math:`v_i` is generated and
    fed through the operator, giving :math:`r_i`. The returned estimate is
    the sample mean of the element-wise products :math:`r_i^\\dagger v_i`,
    evaluated as ``op(v).conjugate()*v``.

    Parameters
    ----------
    op: EndomorphicOperator
        The operator to be probed.
    nprobes: int
        The number of probes to be used.
    random_type: str
        Random distribution for the probing vectors, forwarded to
        `from_random`. With the default `pm1` every entry of a probing
        vector is randomly +1 or -1.

    Returns
    -------
    Field
        The estimated diagonal.

    Raises
    ------
    RuntimeError
        If `nprobes < 1`, because no probe was accumulated.
    '''
    stats = StatCalculator()
    for _ in range(nprobes):
        probe = from_random(random_type, op.domain)
        response = op(probe)
        stats.add(response.conjugate()*probe)
    return stats.mean


def approximation2endo(op, nsamples):
    '''Builds an element-wise approximation of `op` from sample variances.

    Draws `nsamples` samples with ``op.draw_sample()``, computes their
    unbiased variance and replaces every entry that is exactly zero by 1.
    The replacement is presumably there so that the result can be inverted
    safely when it is used as a preconditioner -- confirm with callers.

    Despite the name, the approximation itself is returned, not an operator.

    Parameters
    ----------
    op : Operator
        Any object providing ``draw_sample()``; presumably an
        EndomorphicOperator -- TODO confirm.
    nsamples : int
        Number of samples to draw; must be at least 2.

    Returns
    -------
    MultiField or Field
        The sample variance with zeros replaced by ones. For MultiField
        samples a MultiField is returned, as before. For plain Field samples,
        which previously failed on the missing `to_dict`, the Field itself is
        processed and returned.

    Raises
    ------
    RuntimeError
        If `nsamples < 2`, since the unbiased variance is then undefined.
    '''
    def _ones_for_zeros(fld):
        # Same data transformation for every field: exact zeros become 1.
        arr = fld.to_global_data_rw()
        arr[arr == 0] = 1
        return from_global_data(fld.domain, arr)

    print('Calculate preconditioner')
    sc = StatCalculator()
    for _ in range(nsamples):
        sc.add(op.draw_sample())
    approx = sc.var

    if isinstance(approx, MultiField):
        dct = approx.to_dict()
        return MultiField.from_dict(
            {key: _ones_for_zeros(dct[key]) for key in dct})
    return _ones_for_zeros(approx)