Skip to content
Snippets Groups Projects
Commit 65d6f135 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak diagonal sampling/probing

parent 6bd3a182
Branches
Tags
No related merge requests found
Pipeline #
......@@ -106,6 +106,6 @@ if __name__ == "__main__":
ift.plot(ift.Field(plot_space,val=data.val), name='data.png', **plotdict)
ift.plot(ift.Field(plot_space,val=m.val), name='map.png', **plotdict)
# sampling the uncertainty map
mean, variance = ift.probe_with_posterior_samples(wiener_curvature, m_k, ht, 10)
mean, variance = ift.probe_with_posterior_samples(wiener_curvature, ht, 10)
ift.plot(ift.Field(plot_space, val=ift.sqrt(variance).val), name="uncertainty.png", **plotdict)
ift.plot(ift.Field(plot_space, val=mean.val), name="posterior_mean.png", **plotdict)
ift.plot(ift.Field(plot_space, val=(mean+m).val), name="posterior_mean.png", **plotdict)
......@@ -66,6 +66,6 @@ if __name__ == "__main__":
ift.plot(m, name="map.png", **plotdict)
# sampling the uncertainty map
mean, variance = ift.probe_with_posterior_samples(wiener_curvature, m_k, ht, 5)
mean, variance = ift.probe_with_posterior_samples(wiener_curvature, ht, 5)
ift.plot(ift.sqrt(variance), name="uncertainty.png", **plotdict)
ift.plot(mean, name="posterior_mean.png", **plotdict)
ift.plot(mean+m, name="posterior_mean.png", **plotdict)
......@@ -78,6 +78,11 @@ if __name__ == "__main__":
sample_variance = ift.Field.zeros(s_space)
sample_mean = ift.Field.zeros(s_space)
mean, variance = ift.probe_with_posterior_samples(curv, m, ht, 50)
mean, variance = ift.probe_with_posterior_samples(curv, ht, 50)
ift.plot(variance, name="posterior_variance.png", **plotdict)
ift.plot(mean, name="posterior_mean.png", **plotdict)
ift.plot(mean+ht(m), name="posterior_mean.png", **plotdict)
# try to do the same with diagonal probing
variance = ift.probe_diagonal(ht*curv.inverse*ht.adjoint, 100)
#sm = ift.FFTSmoothingOperator(s_space, sigma=0.015)
ift.plot(variance, name="posterior_variance2.png", **plotdict)
......@@ -32,7 +32,7 @@ from .field import Field, sqrt, exp, log
from .probing.prober import Prober
from .probing.diagonal_prober_mixin import DiagonalProberMixin
from .probing.trace_prober_mixin import TraceProberMixin
from .probing.utils import probe_with_posterior_samples
from .probing.utils import probe_with_posterior_samples, probe_diagonal
from .minimization.line_search import LineSearch
from .minimization.line_search_strong_wolfe import LineSearchStrongWolfe
......
......@@ -11,13 +11,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object
from ..field import Field
class StatCalculator(object):
def __init__(self):
......@@ -47,12 +47,21 @@ class StatCalculator(object):
return self._M2 * (1./(self._count-1))
def probe_with_posterior_samples(op, m, post_op, nprobes):
def probe_with_posterior_samples(op, post_op, nprobes):
sc = StatCalculator()
for i in range(nprobes):
sample = post_op(op.draw_sample() + m)
sample = post_op(op.draw_sample())
sc.add(sample)
if nprobes == 1:
return sc.mean, None
return sc.mean, sc.var
def probe_diagonal(op, nprobes, random_type="normal"):
sc = StatCalculator()
for i in range(nprobes):
input = Field.from_random(random_type, op.domain)
output = op(input)
sc.add(output.conjugate()*input)
return sc.mean
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment