probing.py 4.43 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

18
from .multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
19 20
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.operator import Operator
21
from .sugar import from_global_data, from_random
22

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
23

24
class StatCalculator(object):
    """Helper class to compute mean and variance of a set of inputs.

    Samples are accumulated one at a time using Welford's online update
    (running mean plus running sum of squared deviations ``_M2``). This
    avoids the cancellation error of the naive ``E[x^2] - E[x]^2`` formula.

    Notes
    -----
    - The memory usage of this object is constant, i.e. it does not increase
      with the number of samples added.
    - The code computes the unbiased variance (which contains a `1./(n-1)`
      term for `n` samples).
    """

    def __init__(self):
        # Number of samples added so far. `_mean` and `_M2` are created
        # lazily in `add`, because their type is that of the samples.
        self._count = 0

    def add(self, value):
        """Adds a sample.

        Parameters
        ----------
        value: any type that supports multiplication by a scalar and
               element-wise addition/subtraction/multiplication.
        """
        self._count += 1
        if self._count == 1:
            # `1.*value` yields a new, floating-point object equal to the
            # sample; `0.*value` yields a matching zero to start `_M2` from.
            self._mean = 1.*value
            self._M2 = 0.*value
        else:
            delta = value - self._mean
            # Update from the raw accumulator. Going through the `mean`
            # property would allocate a throw-away copy on every call.
            self._mean = self._mean + delta*(1./self._count)
            delta2 = value - self._mean
            # Welford: accumulate (x - old_mean) * (x - new_mean).
            self._M2 = self._M2 + delta*delta2

    @property
    def mean(self):
        """
        value type : the mean of all samples added so far.

        Raises
        ------
        RuntimeError
            If no sample has been added yet.
        """
        if self._count == 0:
            raise RuntimeError("mean requires at least one sample")
        # Hand out a copy so callers cannot alter the internal accumulator.
        return 1.*self._mean

    @property
    def var(self):
        """
        value type : the unbiased variance of all samples added so far.

        Raises
        ------
        RuntimeError
            If fewer than two samples have been added.
        """
        if self._count < 2:
            raise RuntimeError("variance requires at least two samples")
        return self._M2 * (1./(self._count-1))


def probe_with_posterior_samples(op, post_op, nprobes):
    '''Estimates mean and variance from samples drawn with `op`.

    Draws `nprobes` samples via ``op.draw_sample(from_inverse=True)``.
    Optionally each sample is mapped through `post_op`. The function returns
    the sample mean and the unbiased sample variance of the results.

    Parameters
    ----------
    op : EndomorphicOperator
        Operator whose ``draw_sample(from_inverse=True)`` method provides the
        samples. NOTE(review): presumably a posterior covariance
        approximation; confirm with callers.
    post_op : Operator or None
        Operator applied to every sample before accumulation. Its domain must
        be ``op.target``. If None, the raw samples are used.
    nprobes : int
        Number of samples which shall be drawn. Must be at least 1.

    Returns
    -------
    tuple
        ``(mean, var)``: the sample mean and the unbiased sample variance.
        ``var`` is None if ``nprobes == 1``, because the unbiased variance is
        undefined for a single sample.

    Raises
    ------
    TypeError
        If `op` is not an EndomorphicOperator, or `post_op` is neither None
        nor an Operator.
    ValueError
        If ``post_op.domain`` is not ``op.target``, or if ``nprobes < 1``.
    '''
    if not isinstance(op, EndomorphicOperator):
        raise TypeError("op must be an EndomorphicOperator")
    if post_op is not None:
        if not isinstance(post_op, Operator):
            raise TypeError("post_op must be an Operator or None")
        # Identity comparison is kept from the original check.
        if post_op.domain is not op.target:
            raise ValueError("post_op.domain must be op.target")
    if nprobes < 1:
        # Fail before drawing any samples. Otherwise this would only surface
        # later as a message-less RuntimeError from StatCalculator.mean.
        raise ValueError("nprobes must be at least 1")

    sc = StatCalculator()
    for _ in range(nprobes):
        sample = op.draw_sample(from_inverse=True)
        sc.add(sample if post_op is None else post_op(sample))

    if nprobes == 1:
        return sc.mean, None
    return sc.mean, sc.var
Martin Reinecke's avatar
Martin Reinecke committed
108 109 110


def probe_diagonal(op, nprobes, random_type="pm1"):
    '''Probes the diagonal of an endomorphic operator.

    Draws `nprobes` random vectors :math:`v_i` on ``op.domain`` and applies
    the operator to each, yielding :math:`r_i`. The estimated diagonal is the
    sample mean of the element-wise product :math:`r_i^\\dagger v_i`. In code
    this is ``op(v).conjugate()*v``.

    Parameters
    ----------
    op: EndomorphicOperator
        The operator to be probed.
    nprobes: int
        The number of probes to be used. At least one is required; otherwise
        computing the mean raises a RuntimeError.
    random_type: str
        The kind of random number distribution to be used for the probing.
        The default value `pm1` causes the probing vector to be randomly
        filled with values of +1 and -1.

    Returns
    -------
    Field
        The estimated diagonal.
    '''
    stats = StatCalculator()
    domain = op.domain
    for _ in range(nprobes):
        probe = from_random(random_type, domain)
        response = op(probe)
        stats.add(response.conjugate()*probe)
    return stats.mean
138 139 140 141 142 143 144 145 146 147 148 149 150 151


def approximation2endo(op, nsamples):
    '''Estimates per-entry sample variances of ``op.draw_sample()``.

    Prints a progress message and draws `nsamples` samples with
    ``op.draw_sample()``. It then computes their unbiased variance and
    replaces every entry that is exactly zero by 1. The result is returned
    as a MultiField with the same keys.

    NOTE(review): the zero-to-one replacement presumably keeps the result
    safely invertible when used as a preconditioner; confirm with callers.
    NOTE(review): the samples are assumed to be MultiFields (``to_dict`` is
    called on their variance).

    Parameters
    ----------
    op : object with a ``draw_sample()`` method
        Source of the samples.
    nsamples : int
        Number of samples to draw. At least two are required; otherwise
        ``StatCalculator.var`` raises a RuntimeError.

    Returns
    -------
    MultiField
        The per-entry variances, with exact zeros replaced by 1.
    '''
    print('Calculate preconditioner')
    stats = StatCalculator()
    for _ in range(nsamples):
        stats.add(op.draw_sample())
    variances = stats.var.to_dict()

    def _zeros_to_one(fld):
        # Work on a writable copy of the data, then rebuild on fld's domain.
        arr = fld.to_global_data_rw()
        arr[arr == 0] = 1
        return from_global_data(fld.domain, arr)

    return MultiField.from_dict(
        {key: _zeros_to_one(fld) for key, fld in variances.items()})