diff --git a/nifty5/extra/energy_and_model_tests.py b/nifty5/extra/energy_and_model_tests.py
index 60022c10129a337759b102e3c6622e7518e56692..f151e5756a659ee558128499d9af103c13f93d85 100644
--- a/nifty5/extra/energy_and_model_tests.py
+++ b/nifty5/extra/energy_and_model_tests.py
@@ -41,7 +41,7 @@ def _get_acceptable_location(op, loc, lin):
     for i in range(50):
         try:
             loc2 = loc+dir
-            lin2 = op(Linearization.make_var(loc2))
+            lin2 = op(Linearization.make_var(loc2, lin.want_metric))
             if np.isfinite(lin2.val.sum()) and abs(lin2.val.sum()) < 1e20:
                 break
         except FloatingPointError:
@@ -54,14 +54,14 @@ def _get_acceptable_location(op, loc, lin):
 
 def _check_consistency(op, loc, tol, ntries, do_metric):
     for _ in range(ntries):
-        lin = op(Linearization.make_var(loc))
+        lin = op(Linearization.make_var(loc, do_metric))
         loc2, lin2 = _get_acceptable_location(op, loc, lin)
         dir = loc2-loc
         locnext = loc2
         dirnorm = dir.norm()
         for i in range(50):
             locmid = loc + 0.5*dir
-            linmid = op(Linearization.make_var(locmid))
+            linmid = op(Linearization.make_var(locmid, do_metric))
             dirder = linmid.jac(dir)
             numgrad = (lin2.val-lin.val)
             xtol = tol * dirder.norm() / np.sqrt(dirder.size)
diff --git a/nifty5/library/inverse_gamma_model.py b/nifty5/library/inverse_gamma_model.py
index 9be2acce1a7b4aad7e292849034cca8cd10a60b0..831546453062e734b0e708dd33245e5fd13304e2 100644
--- a/nifty5/library/inverse_gamma_model.py
+++ b/nifty5/library/inverse_gamma_model.py
@@ -53,7 +53,7 @@ class InverseGammaModel(Operator):
         outer = 1/outer_inv
         jac = makeOp(Field.from_local_data(self._domain, inner*outer))
         jac = jac(x.jac)
-        return Linearization(points, jac)
+        return x.new(points, jac)
 
     @staticmethod
     def IG(field, alpha, q):
diff --git a/nifty5/linearization.py b/nifty5/linearization.py
index 41210bdeec24e31147e72080491f16699520aa01..c27dec1abc65794ab57288480f49e59b95053d12 100644
--- a/nifty5/linearization.py
+++ b/nifty5/linearization.py
@@ -9,13 +9,17 @@ from .sugar import makeOp
 
 
 class Linearization(object):
-    def __init__(self, val, jac, metric=None):
+    def __init__(self, val, jac, metric=None, want_metric=False):
         self._val = val
         self._jac = jac
         if self._val.domain != self._jac.target:
             raise ValueError("domain mismatch")
+        self._want_metric = want_metric
         self._metric = metric
 
+    def new(self, val, jac, metric=None):
+        return Linearization(val, jac, metric, self._want_metric)
+
     @property
     def domain(self):
         return self._jac.domain
@@ -37,6 +41,10 @@ class Linearization(object):
         """Only available if target is a scalar"""
         return self._jac.adjoint_times(Field.scalar(1.))
 
+    @property
+    def want_metric(self):
+        return self._want_metric
+
     @property
     def metric(self):
         """Only available if target is a scalar"""
@@ -44,35 +52,34 @@ class Linearization(object):
 
     def __getitem__(self, name):
         from .operators.simple_linear_operators import FieldAdapter
-        return Linearization(self._val[name], FieldAdapter(self.domain, name))
+        return self.new(self._val[name], FieldAdapter(self.domain, name))
 
     def __neg__(self):
-        return Linearization(
-            -self._val, -self._jac,
-            None if self._metric is None else -self._metric)
+        return self.new(-self._val, -self._jac,
+                        None if self._metric is None else -self._metric)
 
     def conjugate(self):
-        return Linearization(
+        return self.new(
             self._val.conjugate(), self._jac.conjugate(),
             None if self._metric is None else self._metric.conjugate())
 
     @property
     def real(self):
-        return Linearization(self._val.real, self._jac.real)
+        return self.new(self._val.real, self._jac.real)
 
     def _myadd(self, other, neg):
         if isinstance(other, Linearization):
             met = None
             if self._metric is not None and other._metric is not None:
                 met = self._metric._myadd(other._metric, neg)
-            return Linearization(
+            return self.new(
                 self._val.flexible_addsub(other._val, neg),
                 self._jac._myadd(other._jac, neg), met)
         if isinstance(other, (int, float, complex, Field, MultiField)):
             if neg:
-                return Linearization(self._val-other, self._jac, self._metric)
+                return self.new(self._val-other, self._jac, self._metric)
             else:
-                return Linearization(self._val+other, self._jac, self._metric)
+                return self.new(self._val+other, self._jac, self._metric)
 
     def __add__(self, other):
         return self._myadd(other, False)
@@ -91,7 +98,7 @@ class Linearization(object):
         if isinstance(other, Linearization):
             if self.target != other.target:
                 raise ValueError("domain mismatch")
-            return Linearization(
+            return self.new(
                 self._val*other._val,
                 (makeOp(other._val)(self._jac))._myadd(
                  makeOp(self._val)(other._jac), False))
@@ -99,11 +106,11 @@ class Linearization(object):
             if other == 1:
                 return self
             met = None if self._metric is None else self._metric.scale(other)
-            return Linearization(self._val*other, self._jac.scale(other), met)
+            return self.new(self._val*other, self._jac.scale(other), met)
         if isinstance(other, (Field, MultiField)):
             if self.target != other.domain:
                 raise ValueError("domain mismatch")
-            return Linearization(self._val*other, makeOp(other)(self._jac))
+            return self.new(self._val*other, makeOp(other)(self._jac))
 
     def __rmul__(self, other):
         return self.__mul__(other)
@@ -111,46 +118,48 @@ class Linearization(object):
     def vdot(self, other):
         from .operators.simple_linear_operators import VdotOperator
         if isinstance(other, (Field, MultiField)):
-            return Linearization(
+            return self.new(
                 Field.scalar(self._val.vdot(other)),
                 VdotOperator(other)(self._jac))
-        return Linearization(
+        return self.new(
             Field.scalar(self._val.vdot(other._val)),
             VdotOperator(self._val)(other._jac) +
             VdotOperator(other._val)(self._jac))
 
     def sum(self):
         from .operators.simple_linear_operators import SumReductionOperator
-        return Linearization(
+        return self.new(
             Field.scalar(self._val.sum()),
             SumReductionOperator(self._jac.target)(self._jac))
 
     def exp(self):
         tmp = self._val.exp()
-        return Linearization(tmp, makeOp(tmp)(self._jac))
+        return self.new(tmp, makeOp(tmp)(self._jac))
 
     def log(self):
         tmp = self._val.log()
-        return Linearization(tmp, makeOp(1./self._val)(self._jac))
+        return self.new(tmp, makeOp(1./self._val)(self._jac))
 
     def tanh(self):
         tmp = self._val.tanh()
-        return Linearization(tmp, makeOp(1.-tmp**2)(self._jac))
+        return self.new(tmp, makeOp(1.-tmp**2)(self._jac))
 
     def positive_tanh(self):
         tmp = self._val.tanh()
         tmp2 = 0.5*(1.+tmp)
-        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
+        return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
 
     def add_metric(self, metric):
-        return Linearization(self._val, self._jac, metric)
+        return self.new(self._val, self._jac, metric)
 
     @staticmethod
-    def make_var(field):
+    def make_var(field, want_metric=False):
         from .operators.scaling_operator import ScalingOperator
-        return Linearization(field, ScalingOperator(1., field.domain))
+        return Linearization(field, ScalingOperator(1., field.domain),
+                             want_metric=want_metric)
 
     @staticmethod
-    def make_const(field):
+    def make_const(field, want_metric=False):
         from .operators.simple_linear_operators import NullOperator
-        return Linearization(field, NullOperator(field.domain, field.domain))
+        return Linearization(field, NullOperator(field.domain, field.domain),
+                             want_metric=want_metric)
diff --git a/nifty5/minimization/energy_adapter.py b/nifty5/minimization/energy_adapter.py
index 7d5c0e29684812cb8ce203122b301c038a8d0621..d0e71a96fd4c011d35fb4433bb9b8afe0824c2cd 100644
--- a/nifty5/minimization/energy_adapter.py
+++ b/nifty5/minimization/energy_adapter.py
@@ -8,17 +8,18 @@ from ..operators.scaling_operator import ScalingOperator
 
 
 class EnergyAdapter(Energy):
-    def __init__(self, position, op, constants=[]):
+    def __init__(self, position, op, constants=[], want_metric=False):
         super(EnergyAdapter, self).__init__(position)
         self._op = op
         self._constants = constants
         if len(self._constants) == 0:
-            tmp = self._op(Linearization.make_var(self._position))
+            tmp = self._op(Linearization.make_var(self._position, want_metric))
         else:
             ops = [ScalingOperator(0. if key in self._constants else 1., dom)
                    for key, dom in self._position.domain.items()]
             bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
-            tmp = self._op(Linearization(self._position, bdop))
+            tmp = self._op(Linearization(self._position, bdop,
+                                         want_metric=want_metric))
         self._val = tmp.val.local_data[()]
         self._grad = tmp.gradient
         self._metric = tmp._metric
diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py
index d21dbdafe6ca3358fc064b12986df6d0abe406f2..b00e7a46c47dec305f670f192094430a23688218 100644
--- a/nifty5/minimization/kl_energy.py
+++ b/nifty5/minimization/kl_energy.py
@@ -9,22 +9,24 @@ from .. import utilities
 
 
 class KL_Energy(Energy):
-    def __init__(self, position, h, nsamp, constants=[], _samples=None):
+    def __init__(self, position, h, nsamp, constants=[], _samples=None,
+                 want_metric=False):
         super(KL_Energy, self).__init__(position)
         self._h = h
         self._constants = constants
+        self._want_metric = want_metric
         if _samples is None:
-            met = h(Linearization.make_var(position)).metric
+            met = h(Linearization.make_var(position, True)).metric
             _samples = tuple(met.draw_sample(from_inverse=True)
                              for _ in range(nsamp))
         self._samples = _samples
         if len(constants) == 0:
-            tmp = Linearization.make_var(position)
+            tmp = Linearization.make_var(position, want_metric)
         else:
             ops = [ScalingOperator(0. if key in constants else 1., dom)
                    for key, dom in position.domain.items()]
             bdop = BlockDiagonalOperator(position.domain, tuple(ops))
-            tmp = Linearization(position, bdop)
+            tmp = Linearization(position, bdop, want_metric=want_metric)
         mymap = map(lambda v: self._h(tmp+v), self._samples)
         tmp = utilities.my_sum(mymap) * (1./len(self._samples))
         self._val = tmp.val.local_data[()]
@@ -32,7 +34,8 @@ class KL_Energy(Energy):
         self._metric = tmp.metric
 
     def at(self, position):
-        return KL_Energy(position, self._h, 0, self._constants, self._samples)
+        return KL_Energy(position, self._h, 0, self._constants, self._samples,
+                         self._want_metric)
 
     @property
     def value(self):
diff --git a/nifty5/operators/energy_operators.py b/nifty5/operators/energy_operators.py
index 7d502421e295b5f3f79697b9cf187ac202e56440..a68ea92527a4850bd79976cb25eab87d2e8d9d7c 100644
--- a/nifty5/operators/energy_operators.py
+++ b/nifty5/operators/energy_operators.py
@@ -42,7 +42,7 @@ class SquaredNormOperator(EnergyOperator):
         if isinstance(x, Linearization):
             val = Field.scalar(x.val.vdot(x.val))
             jac = VdotOperator(2*x.val)(x.jac)
-            return Linearization(val, jac)
+            return x.new(val, jac)
         return Field.scalar(x.vdot(x))
 
 
@@ -59,7 +59,7 @@ class QuadraticFormOperator(EnergyOperator):
             t1 = self._op(x.val)
             jac = VdotOperator(t1)(x.jac)
             val = Field.scalar(0.5*x.val.vdot(t1))
-            return Linearization(val, jac)
+            return x.new(val, jac)
         return Field.scalar(0.5*x.vdot(self._op(x)))
 
 
@@ -91,7 +91,7 @@ class GaussianEnergy(EnergyOperator):
     def apply(self, x):
         residual = x if self._mean is None else x-self._mean
         res = self._op(residual).real
-        if not isinstance(x, Linearization):
+        if not isinstance(x, Linearization) or not x.want_metric:
             return res
         metric = SandwichOperator.make(x.jac, self._icov)
         return res.add_metric(metric)
@@ -107,6 +107,8 @@ class PoissonianEnergy(EnergyOperator):
         res = x.sum() - x.log().vdot(self._d)
         if not isinstance(x, Linearization):
             return Field.scalar(res)
+        if not x.want_metric:
+            return res
         metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
         return res.add_metric(metric)
 
@@ -122,6 +124,8 @@ class BernoulliEnergy(EnergyOperator):
         v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
         if not isinstance(x, Linearization):
             return Field.scalar(v)
+        if not x.want_metric:
+            return v
         met = makeOp(1./(x.val*(1.-x.val)))
         met = SandwichOperator.make(x.jac, met)
         return v.add_metric(met)
@@ -135,11 +139,11 @@ class Hamiltonian(EnergyOperator):
         self._domain = lh.domain
 
     def apply(self, x):
-        if self._ic_samp is None or not isinstance(x, Linearization):
+        if (self._ic_samp is None or not isinstance(x, Linearization) or
+                not x.want_metric):
             return self._lh(x)+self._prior(x)
         else:
-            lhx = self._lh(x)
-            prx = self._prior(x)
+            lhx, prx = self._lh(x), self._prior(x)
             mtr = SamplingEnabler(lhx.metric, prx.metric.inverse,
                                   self._ic_samp, prx.metric.inverse)
             return (lhx+prx).add_metric(mtr)
diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py
index 26a0cf81dc0dfac484ce06f2ba567158198c09c6..aa330c2f4ccede83ad73de5b72810da71ed542ba 100644
--- a/nifty5/operators/linear_operator.py
+++ b/nifty5/operators/linear_operator.py
@@ -175,7 +175,7 @@ class LinearOperator(Operator):
             return self.apply(x, self.TIMES)
         from ..linearization import Linearization
         if isinstance(x, Linearization):
-            return Linearization(self(x._val), self(x._jac))
+            return x.new(self(x._val), self(x._jac))
         return self.__matmul__(x)
 
     def times(self, x):
diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py
index 73ece15e64ef3e1f64b8d8e27f75e09f4ced4f89..13b7e6fb7bfb1e27f12e3f292744b239f868b8b5 100644
--- a/nifty5/operators/operator.py
+++ b/nifty5/operators/operator.py
@@ -144,11 +144,12 @@ class _OpProd(Operator):
         v2 = v.extract(self._op2.domain)
         if not lin:
             return self._op1(v1) * self._op2(v2)
-        lin1 = self._op1(Linearization.make_var(v1))
-        lin2 = self._op2(Linearization.make_var(v2))
+        wm = x.want_metric
+        lin1 = self._op1(Linearization.make_var(v1, wm))
+        lin2 = self._op2(Linearization.make_var(v2, wm))
         op = (makeOp(lin1._val)(lin2._jac))._myadd(
             makeOp(lin2._val)(lin1._jac), False)
-        return Linearization(lin1._val*lin2._val, op(x.jac))
+        return lin1.new(lin1._val*lin2._val, op(x.jac))
 
 
 class _OpSum(Operator):
@@ -168,10 +169,11 @@ class _OpSum(Operator):
         res = None
         if not lin:
             return self._op1(v1).unite(self._op2(v2))
-        lin1 = self._op1(Linearization.make_var(v1))
-        lin2 = self._op2(Linearization.make_var(v2))
+        wm = x.want_metric
+        lin1 = self._op1(Linearization.make_var(v1, wm))
+        lin2 = self._op2(Linearization.make_var(v2, wm))
         op = lin1._jac._myadd(lin2._jac, False)
-        res = Linearization(lin1._val+lin2._val, op(x.jac))
+        res = lin1.new(lin1._val+lin2._val, op(x.jac))
         if lin1._metric is not None and lin2._metric is not None:
             res = res.add_metric(lin1._metric + lin2._metric)
         return res