From e25736eef2c7a93fd81daf37a90a337fac02fd02 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sun, 26 Apr 2020 13:48:12 +0200
Subject: [PATCH] impoement feedback

---
 demos/getting_started_0.ipynb |  4 ++--
 nifty6/probing.py             | 10 +++++-----
 nifty6/sugar.py               |  3 +--
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/demos/getting_started_0.ipynb b/demos/getting_started_0.ipynb
index 6276134ba..ff2cddb70 100644
--- a/demos/getting_started_0.ipynb
+++ b/demos/getting_started_0.ipynb
@@ -471,7 +471,7 @@
    },
    "outputs": [],
    "source": [
-    "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200)"
+    "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200, np.float64)"
    ]
   },
   {
@@ -598,7 +598,7 @@
     "m = D(j)\n",
     "\n",
     "# Uncertainty\n",
-    "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20)\n",
+    "m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20, np.float64)\n",
     "\n",
     "# Get data\n",
     "s_data = HT(sh).val\n",
diff --git a/nifty6/probing.py b/nifty6/probing.py
index 6aeb98bc8..cf08e9d13 100644
--- a/nifty6/probing.py
+++ b/nifty6/probing.py
@@ -71,7 +71,7 @@ class StatCalculator(object):
         return self._M2 * (1./(self._count-1))
 
 
-def probe_with_posterior_samples(op, post_op, nprobes):
+def probe_with_posterior_samples(op, post_op, nprobes, dtype):
     '''FIXME
 
     Parameters
@@ -82,6 +82,8 @@ def probe_with_posterior_samples(op, post_op, nprobes):
         FIXME
     nprobes : int
         Number of samples which shall be drawn.
+    dtype :
+        the data type of the samples
 
     Returns
     -------
@@ -97,12 +99,10 @@ def probe_with_posterior_samples(op, post_op, nprobes):
             raise ValueError
     sc = StatCalculator()
     for i in range(nprobes):
-        # FIXME which dtype should we use here?
-        import numpy as np
         if post_op is None:
-            sc.add(op.draw_sample(dtype=np.float64, from_inverse=True))
+            sc.add(op.draw_sample(dtype=dtype, from_inverse=True))
         else:
-            sc.add(post_op(op.draw_sample(dtype=np.float64, from_inverse=True)))
+            sc.add(post_op(op.draw_sample(dtype=dtype, from_inverse=True)))
 
     if nprobes == 1:
         return sc.mean, None
diff --git a/nifty6/sugar.py b/nifty6/sugar.py
index a1bcdef58..b0ffe1efa 100644
--- a/nifty6/sugar.py
+++ b/nifty6/sugar.py
@@ -498,8 +498,7 @@ def calculate_position(operator, output):
     else:
         cov = 1e-3*output.val.max()**2
     invcov = ScalingOperator(output.domain, cov).inverse
-    # FIXME!!!
-    d = output + invcov.draw_sample(dtype=np.float64, from_inverse=True)
+    d = output + invcov.draw_sample(dtype=output.dtype, from_inverse=True)
     lh = GaussianEnergy(d, invcov) @ operator
     H = StandardHamiltonian(
         lh, ic_samp=GradientNormController(iteration_limit=200))
-- 
GitLab