diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py
index 4b9b7e59ab180f39e522b771552e75117397964b..65d5c66537c9201f251c0e9c667ccfd441c83031 100644
--- a/demos/getting_started_3.py
+++ b/demos/getting_started_3.py
@@ -72,8 +72,8 @@ if __name__ == '__main__':
 
     # set up minimization and inversion schemes
     ic_sampling = ift.GradientNormController(iteration_limit=100)
-    ic_newton = ift.DeltaEnergyController(
-        name='Newton', tol_rel_deltaE=1e-8, iteration_limit=100)
+    ic_newton = ift.GradInfNormController(
+        name='Newton', tol=1e-7, iteration_limit=1000)
     minimizer = ift.NewtonCG(ic_newton)
 
     # build model Hamiltonian
@@ -91,7 +91,7 @@ if __name__ == '__main__':
     # number of samples used to estimate the KL
     N_samples = 20
     for i in range(2):
-        KL = ift.KL_Energy(position, H, N_samples, want_metric=True)
+        KL = ift.KL_Energy(position, H, N_samples)
         KL, convergence = minimizer(KL)
         position = KL.position
 
diff --git a/nifty5/linearization.py b/nifty5/linearization.py
index c27dec1abc65794ab57288480f49e59b95053d12..30a9b57bad11b5e264a0997bde6c0e0e2bbefe10 100644
--- a/nifty5/linearization.py
+++ b/nifty5/linearization.py
@@ -152,6 +152,9 @@ class Linearization(object):
     def add_metric(self, metric):
         return self.new(self._val, self._jac, metric)
 
+    def with_want_metric(self):
+        return Linearization(self._val, self._jac, self._metric, True)
+
     @staticmethod
     def make_var(field, want_metric=False):
         from .operators.scaling_operator import ScalingOperator
@@ -163,3 +166,15 @@ class Linearization(object):
         from .operators.simple_linear_operators import NullOperator
         return Linearization(field, NullOperator(field.domain, field.domain),
                              want_metric=want_metric)
+
+    @staticmethod
+    def make_partial_var(field, constants, want_metric=False):
+        from .operators.scaling_operator import ScalingOperator
+        from .operators.simple_linear_operators import NullOperator
+        if len(constants) == 0:
+            return Linearization.make_var(field, want_metric)
+        else:
+            ops = [ScalingOperator(0. if key in constants else 1., dom)
+                   for key, dom in field.domain.items()]
+            bdop = BlockDiagonalOperator(fielld.domain, tuple(ops))
+            return Linearization(field, bdop, want_metric=want_metric)
diff --git a/nifty5/minimization/energy_adapter.py b/nifty5/minimization/energy_adapter.py
index 985459cc56f162fe7b5306577d605b598129e4fc..44c421fd5bb81299d311ae0c508f1283fea589cc 100644
--- a/nifty5/minimization/energy_adapter.py
+++ b/nifty5/minimization/energy_adapter.py
@@ -13,14 +13,8 @@ class EnergyAdapter(Energy):
         self._op = op
         self._constants = constants
         self._want_metric = want_metric
-        if len(self._constants) == 0:
-            tmp = self._op(Linearization.make_var(self._position, want_metric))
-        else:
-            ops = [ScalingOperator(0. if key in self._constants else 1., dom)
-                   for key, dom in self._position.domain.items()]
-            bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
-            tmp = self._op(Linearization(self._position, bdop,
-                                         want_metric=want_metric))
+        lin = Linearization.make_partial_var(position, constants, want_metric)
+        tmp = self._op(lin)
         self._val = tmp.val.local_data[()]
         self._grad = tmp.gradient
         self._metric = tmp._metric
diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py
index d7a98364d6024cbada97acca11d0d460864c8280..a98da87f8a3d09d3c501341e2d284ef6b55a4afb 100644
--- a/nifty5/minimization/kl_energy.py
+++ b/nifty5/minimization/kl_energy.py
@@ -9,33 +9,32 @@ from .. import utilities
 
 
 class KL_Energy(Energy):
-    def __init__(self, position, h, nsamp, constants=[], _samples=None,
-                 want_metric=False):
+    def __init__(self, position, h, nsamp, constants=[], _samples=None):
         super(KL_Energy, self).__init__(position)
         self._h = h
         self._constants = constants
-        self._want_metric = want_metric
         if _samples is None:
             met = h(Linearization.make_var(position, True)).metric
             _samples = tuple(met.draw_sample(from_inverse=True)
                              for _ in range(nsamp))
         self._samples = _samples
-        if len(constants) == 0:
-            tmp = Linearization.make_var(position, want_metric)
-        else:
-            ops = [ScalingOperator(0. if key in constants else 1., dom)
-                   for key, dom in position.domain.items()]
-            bdop = BlockDiagonalOperator(position.domain, tuple(ops))
-            tmp = Linearization(position, bdop, want_metric=want_metric)
-        mymap = map(lambda v: self._h(tmp+v), self._samples)
-        tmp = utilities.my_sum(mymap) * (1./len(self._samples))
-        self._val = tmp.val.local_data[()]
-        self._grad = tmp.gradient
-        self._metric = tmp.metric
+
+        self._lin = Linearization.make_partial_var(position, constants)
+        v, g = None, None
+        for s in self._samples:
+            tmp = self._h(self._lin+s)
+            if v is None:
+                v = tmp.val.local_data[()]
+                g = tmp.gradient
+            else:
+                v += tmp.val.local_data[()]
+                g = g + tmp.gradient
+        self._val = v / len(self._samples)
+        self._grad = g * (1./len(self._samples))
+        self._metric = None
 
     def at(self, position):
-        return KL_Energy(position, self._h, 0, self._constants, self._samples,
-                         self._want_metric)
+        return KL_Energy(position, self._h, 0, self._constants, self._samples)
 
     @property
     def value(self):
@@ -45,11 +44,20 @@ class KL_Energy(Energy):
     def gradient(self):
         return self._grad
 
+    def _get_metric(self):
+        if self._metric is None:
+            lin = self._lin.with_want_metric()
+            mymap = map(lambda v: self._h(lin+v).metric, self._samples)
+            self._metric = utilities.my_sum(mymap)
+            self._metric = self._metric.scale(1./len(self._samples))
+
     def apply_metric(self, x):
+        self._get_metric()
         return self._metric(x)
 
     @property
     def metric(self):
+        self._get_metric()
         return self._metric
 
     @property