# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from .field import Field
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.operator import Operator


class StatCalculator(object):
    """Helper class to compute mean and variance of a set of inputs.

    Samples are fed in one at a time via :meth:`add`; the running mean and
    the running sum of squared deviations are updated in a single pass
    (Welford-style update), so no sample is stored.

    Notes
    -----
    - the memory usage of this object is constant, i.e. it does not increase
      with the number of samples added
    - the code computes the unbiased variance (which contains a `1./(n-1)`
      term for `n` samples).
    """

    def __init__(self):
        # `_mean` and `_M2` are created lazily by the first call to `add`,
        # because their type is only known once a sample has been seen.
        self._count = 0

    def add(self, value):
        """Adds a sample.

        Parameters
        ----------
        value: any type that supports multiplication by a scalar and
            element-wise addition/subtraction/multiplication.
        """
        self._count += 1
        if self._count == 1:
            # Multiplying by a float creates fresh objects, so the
            # accumulator does not alias the caller's sample.
            self._mean = 1.*value
            self._M2 = 0.*value
        else:
            # `delta` uses the old mean, `delta2` the updated one; their
            # product is the increment of the sum of squared deviations.
            delta = value - self._mean
            self._mean = self._mean + delta*(1./self._count)
            delta2 = value - self._mean
            self._M2 = self._M2 + delta*delta2

    @property
    def mean(self):
        """
        value type : the mean of all samples added so far.

        Raises
        ------
        RuntimeError
            If no sample has been added yet.
        """
        if self._count == 0:
            raise RuntimeError("no samples have been added yet")
        # Return a copy so that callers cannot modify the internal state.
        return 1.*self._mean

    @property
    def var(self):
        """
        value type : the unbiased variance of all samples added so far.

        Raises
        ------
        RuntimeError
            If fewer than two samples have been added.
        """
        if self._count < 2:
            raise RuntimeError(
                "at least two samples are needed to compute the variance")
        return self._M2 * (1./(self._count-1))


def probe_with_posterior_samples(op, post_op, nprobes):
    """Estimates mean and variance of (optionally post-processed) samples.

    `nprobes` samples are drawn by calling
    `op.draw_sample(from_inverse=True)`. If `post_op` is given, it is applied
    to every sample before the sample enters the statistics.

    Parameters
    ----------
    op : EndomorphicOperator
        The operator from which the samples are drawn.
    post_op : Operator or None
        Operator applied to each sample before accumulation. If None, the
        samples are accumulated unchanged. Its domain must be the very same
        object as `op.target`.
    nprobes : int
        Number of samples which shall be drawn.

    Returns
    -------
    tuple
        Two entries: the mean and the unbiased variance of the samples.
        The variance entry is None if `nprobes` is 1, since the variance
        cannot be computed from a single sample.

    Raises
    ------
    TypeError
        If `op` is not an EndomorphicOperator, or if `post_op` is neither
        None nor an Operator.
    ValueError
        If `post_op.domain` is not `op.target`.
    RuntimeError
        If `nprobes` is smaller than 1, because no mean can be computed
        without samples.
    """
    if not isinstance(op, EndomorphicOperator):
        raise TypeError("op must be an EndomorphicOperator")
    if post_op is not None:
        if not isinstance(post_op, Operator):
            raise TypeError("post_op must be an Operator or None")
        if post_op.domain is not op.target:
            raise ValueError("post_op.domain must be identical to op.target")

    sc = StatCalculator()
    for _ in range(nprobes):
        sample = op.draw_sample(from_inverse=True)
        sc.add(sample if post_op is None else post_op(sample))

    if nprobes == 1:
        return sc.mean, None
    return sc.mean, sc.var


def probe_diagonal(op, nprobes, random_type="pm1"):
    r"""Probes the diagonal of an endomorphic operator.

    The operator is called on a user-specified number of randomly generated
    input vectors :math:`v_i`, producing :math:`r_i`. The estimated diagonal
    is the mean of :math:`r_i^\dagger v_i`.

    Parameters
    ----------
    op: EndomorphicOperator
        The operator to be probed.
    nprobes: int
        The number of probes to be used.
    random_type: str
        The kind of random number distribution to be used for the probing.
        The default value `pm1` causes the probing vector to be randomly
        filled with values of +1 and -1.

    Returns
    -------
    Field
        The estimated diagonal.

    Raises
    ------
    RuntimeError
        If `nprobes` is smaller than 1, because no mean can be computed
        without probes.
    """
    sc = StatCalculator()
    for _ in range(nprobes):
        probe = Field.from_random(random_type, op.domain)
        response = op(probe)
        sc.add(response.conjugate()*probe)
    return sc.mean