From e25736eef2c7a93fd81daf37a90a337fac02fd02 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Sun, 26 Apr 2020 13:48:12 +0200 Subject: [PATCH] impoement feedback --- demos/getting_started_0.ipynb | 4 ++-- nifty6/probing.py | 10 +++++----- nifty6/sugar.py | 3 +-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/demos/getting_started_0.ipynb b/demos/getting_started_0.ipynb index 6276134ba..ff2cddb70 100644 --- a/demos/getting_started_0.ipynb +++ b/demos/getting_started_0.ipynb @@ -471,7 +471,7 @@ }, "outputs": [], "source": [ - "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200)" + "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200, np.float64)" ] }, { @@ -598,7 +598,7 @@ "m = D(j)\n", "\n", "# Uncertainty\n", - "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20)\n", + "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20, np.float64)\n", "\n", "# Get data\n", "s_data = HT(sh).val\n", diff --git a/nifty6/probing.py b/nifty6/probing.py index 6aeb98bc8..cf08e9d13 100644 --- a/nifty6/probing.py +++ b/nifty6/probing.py @@ -71,7 +71,7 @@ class StatCalculator(object): return self._M2 * (1./(self._count-1)) -def probe_with_posterior_samples(op, post_op, nprobes): +def probe_with_posterior_samples(op, post_op, nprobes, dtype): '''FIXME Parameters @@ -82,6 +82,8 @@ def probe_with_posterior_samples(op, post_op, nprobes): FIXME nprobes : int Number of samples which shall be drawn. + dtype : + the data type of the samples Returns ------- @@ -97,12 +99,10 @@ def probe_with_posterior_samples(op, post_op, nprobes): raise ValueError sc = StatCalculator() for i in range(nprobes): - # FIXME which dtype should we use here? - import numpy as np if post_op is None: - sc.add(op.draw_sample(dtype=np.float64, from_inverse=True)) + sc.add(op.draw_sample(dtype=dtype, from_inverse=True)) else: - sc.add(post_op(op.draw_sample(dtype=np.float64, from_inverse=True))) + sc.add(post_op(op.draw_sample(dtype=dtype, from_inverse=True))) if nprobes == 1: return sc.mean, None diff --git a/nifty6/sugar.py b/nifty6/sugar.py index a1bcdef58..b0ffe1efa 100644 --- a/nifty6/sugar.py +++ b/nifty6/sugar.py @@ -498,8 +498,7 @@ def calculate_position(operator, output): else: cov = 1e-3*output.val.max()**2 invcov = ScalingOperator(output.domain, cov).inverse - # FIXME!!! - d = output + invcov.draw_sample(dtype=np.float64, from_inverse=True) + d = output + invcov.draw_sample(dtype=output.dtype, from_inverse=True) lh = GaussianEnergy(d, invcov) @ operator H = StandardHamiltonian( lh, ic_samp=GradientNormController(iteration_limit=200)) -- GitLab