diff --git a/nifty5/minimization/metric_gaussian_kl.py b/nifty5/minimization/metric_gaussian_kl.py
index c42844a36287fa538eeee00b0e5278527df46b57..74f83a39409da890a888bb73e4c70cda93c80cd3 100644
--- a/nifty5/minimization/metric_gaussian_kl.py
+++ b/nifty5/minimization/metric_gaussian_kl.py
@@ -18,6 +18,8 @@
 from .. import utilities
 from ..linearization import Linearization
 from ..operators.energy_operators import StandardHamiltonian
+from ..probing import approximation2endo
+from ..sugar import makeOp
 from .energy import Energy
 
 
@@ -56,6 +58,9 @@ class MetricGaussianKL(Energy):
         as they are equally legitimate samples. If true, the number of used
         samples doubles. Mirroring samples stabilizes the KL estimate as
         extreme sample variation is counterbalanced. Default is False.
+    napprox : int
+        Number of samples used to estimate a preconditioner for sampling.
+        By default (napprox=0) no preconditioning is performed.
     _samples : None
         Only a parameter for internal uses. Typically not to be set by users.
 
@@ -67,12 +72,13 @@ class MetricGaussianKL(Energy):
 
     See also
     --------
-    Metric Gaussian Variational Inference (FIXME in preparation)
+    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
     """
 
     def __init__(self, mean, hamiltonian, n_samples, constants=[],
                  point_estimates=[], mirror_samples=False,
-                 _samples=None):
+                 napprox=0, _samples=None):
         super(MetricGaussianKL, self).__init__(mean)
 
         if not isinstance(hamiltonian, StandardHamiltonian):
@@ -91,12 +97,15 @@ class MetricGaussianKL(Energy):
         if _samples is None:
             met = hamiltonian(Linearization.make_partial_var(
                 mean, point_estimates, True)).metric
+            if napprox > 1:
+                met._approximation = makeOp(approximation2endo(met, napprox))
             _samples = tuple(met.draw_sample(from_inverse=True)
                              for _ in range(n_samples))
             if mirror_samples:
                 _samples += tuple(-s for s in _samples)
         self._samples = _samples
 
+        # FIXME Use simplify for constant input instead
         self._lin = Linearization.make_partial_var(mean, constants)
         v, g = None, None
         for s in self._samples:
@@ -110,11 +119,12 @@ class MetricGaussianKL(Energy):
         self._val = v / len(self._samples)
         self._grad = g * (1./len(self._samples))
         self._metric = None
+        self._napprox = napprox
 
     def at(self, position):
         return MetricGaussianKL(position, self._hamiltonian, 0,
                                 self._constants, self._point_estimates,
-                                _samples=self._samples)
+                                napprox=self._napprox, _samples=self._samples)
 
     @property
     def value(self):
@@ -129,8 +139,12 @@ class MetricGaussianKL(Energy):
             lin = self._lin.with_want_metric()
             mymap = map(lambda v: self._hamiltonian(lin+v).metric,
                         self._samples)
-            self._metric = utilities.my_sum(mymap)
-            self._metric = self._metric.scale(1./len(self._samples))
+            self._unscaled_metric = utilities.my_sum(mymap)
+            self._metric = self._unscaled_metric.scale(1./len(self._samples))
+
+    def unscaled_metric(self):  # Return the un-normalized metric plus its normalization factor.
+        self._get_metric()  # lazily computes and caches _unscaled_metric/_metric from the samples
+        return self._unscaled_metric, 1/len(self._samples)  # (sum of per-sample metrics, 1/N scale)
 
     def apply_metric(self, x):
         self._get_metric()
diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py
index 80905fbfb395930367f8af57d042939b692340ec..9fd8efab3fbc08f1295bfa80812ee2b949b651a0 100644
--- a/nifty5/operators/simple_linear_operators.py
+++ b/nifty5/operators/simple_linear_operators.py
@@ -326,7 +326,7 @@ class NullOperator(LinearOperator):
         return self._nullfield(self._tgt(mode))
 
 
-class _PartialExtractor(LinearOperator):
+class PartialExtractor(LinearOperator):
     def __init__(self, domain, target):
         if not isinstance(domain, MultiDomain):
             raise TypeError("MultiDomain expected")
@@ -335,7 +335,7 @@ class _PartialExtractor(LinearOperator):
         self._domain = domain
         self._target = target
         for key in self._target.keys():
-            if not (self._domain[key] is not self._target[key]):
+            if self._domain[key] is not self._target[key]:
                 raise ValueError("domain mismatch")
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
diff --git a/nifty5/probing.py b/nifty5/probing.py
index e5c06392258ba5286868ca63217f7789bdea9d9e..2c1ad8a3fae03ce3f6cdd4970819cbfdc5385049 100644
--- a/nifty5/probing.py
+++ b/nifty5/probing.py
@@ -15,9 +15,10 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
+from .multi_field import MultiField
 from .operators.endomorphic_operator import EndomorphicOperator
 from .operators.operator import Operator
-from .sugar import from_random
+from .sugar import from_global_data, from_random
 
 
 class StatCalculator(object):
@@ -134,3 +135,17 @@ def probe_diagonal(op, nprobes, random_type="pm1"):
         x = from_random(random_type, op.domain)
         sc.add(op(x).conjugate()*x)
     return sc.mean
+
+
+def approximation2endo(op, nsamples):  # Build a diagonal (endomorphic) approximation of op from prior samples.
+    print('Calculate preconditioner')  # NOTE(review): unconditional print in library code; logging may be preferable
+    sc = StatCalculator()
+    for _ in range(nsamples):
+        sc.add(op.draw_sample())  # accumulate samples of op to estimate per-pixel statistics
+    approx = sc.var  # sample variance serves as the diagonal approximation
+    dct = approx.to_dict()
+    for kk in dct:
+        foo = dct[kk].to_global_data_rw()
+        foo[foo == 0] = 1  # replace exact zeros so the diagonal stays invertible for preconditioning
+        dct[kk] = from_global_data(dct[kk].domain, foo)
+    return MultiField.from_dict(dct)  # MultiField diagonal, to be wrapped via makeOp by the caller
diff --git a/test/test_kl.py b/test/test_kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..79428c5b5c0a176109687e6ba01a06a562f72f38
--- /dev/null
+++ b/test/test_kl.py
@@ -0,0 +1,82 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2019 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
+
+import numpy as np
+
+import nifty5 as ift
+from numpy.testing import assert_, assert_allclose
+import pytest
+
+pmp = pytest.mark.parametrize
+
+
+@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
+@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
+@pmp('mirror_samples', (True, False))
+def test_kl(constants, point_estimates, mirror_samples):
+    np.random.seed(42)  # fixed seed: the sample draws below must be reproducible
+    dom = ift.RGSpace((12,), (2.12,))  # distances as an explicit 1-tuple; (2.12) was a parenthesized scalar, not a tuple
+    op0 = ift.HarmonicSmoothingOperator(dom, 3)
+    op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b'))  # model with two latent keys 'a' and 'b'
+    lh = ift.GaussianEnergy(domain=op.target) @ op
+    ic = ift.GradientNormController(iteration_limit=5)
+    h = ift.StandardHamiltonian(lh, ic_samp=ic)
+    mean0 = ift.from_random('normal', h.domain)
+
+    nsamps = 2
+    kl = ift.MetricGaussianKL(mean0,
+                              h,
+                              nsamps,
+                              constants=constants,
+                              point_estimates=point_estimates,
+                              mirror_samples=mirror_samples,
+                              napprox=0)
+    klpure = ift.MetricGaussianKL(mean0,
+                                  h,
+                                  nsamps,
+                                  mirror_samples=mirror_samples,
+                                  napprox=0,
+                                  _samples=kl.samples)
+
+    # Test value: same samples must give the same KL value regardless of constants
+    assert_allclose(kl.value, klpure.value)
+
+    # Test gradient: constant keys must have exactly zero gradient
+    for kk in h.domain.keys():
+        res0 = klpure.gradient.to_global_data()[kk]
+        if kk in constants:
+            res0 = 0*res0
+        res1 = kl.gradient.to_global_data()[kk]
+        assert_allclose(res0, res1)
+
+    # Test number of samples (mirroring doubles the stored sample count)
+    expected_nsamps = 2*nsamps if mirror_samples else nsamps
+    assert_(len(kl.samples) == expected_nsamps)
+
+    # Test point_estimates (after drawing samples): those keys are not sampled
+    for kk in point_estimates:
+        for ss in kl.samples:
+            ss = ss.to_global_data()[kk]
+            assert_allclose(ss, 0*ss)
+
+    # Test constants (after some minimization): constant keys must not move
+    cg = ift.GradientNormController(iteration_limit=5)
+    minimizer = ift.NewtonCG(cg)
+    kl, _ = minimizer(kl)
+    diff = (mean0 - kl.position).to_global_data()
+    for kk in constants:
+        assert_allclose(diff[kk], 0*diff[kk])