diff --git a/nifty5/extra/energy_and_model_tests.py b/nifty5/extra/energy_and_model_tests.py index 60022c10129a337759b102e3c6622e7518e56692..f151e5756a659ee558128499d9af103c13f93d85 100644 --- a/nifty5/extra/energy_and_model_tests.py +++ b/nifty5/extra/energy_and_model_tests.py @@ -41,7 +41,7 @@ def _get_acceptable_location(op, loc, lin): for i in range(50): try: loc2 = loc+dir - lin2 = op(Linearization.make_var(loc2)) + lin2 = op(Linearization.make_var(loc2, lin.want_metric)) if np.isfinite(lin2.val.sum()) and abs(lin2.val.sum()) < 1e20: break except FloatingPointError: @@ -54,14 +54,14 @@ def _get_acceptable_location(op, loc, lin): def _check_consistency(op, loc, tol, ntries, do_metric): for _ in range(ntries): - lin = op(Linearization.make_var(loc)) + lin = op(Linearization.make_var(loc, do_metric)) loc2, lin2 = _get_acceptable_location(op, loc, lin) dir = loc2-loc locnext = loc2 dirnorm = dir.norm() for i in range(50): locmid = loc + 0.5*dir - linmid = op(Linearization.make_var(locmid)) + linmid = op(Linearization.make_var(locmid, do_metric)) dirder = linmid.jac(dir) numgrad = (lin2.val-lin.val) xtol = tol * dirder.norm() / np.sqrt(dirder.size) diff --git a/nifty5/library/inverse_gamma_model.py b/nifty5/library/inverse_gamma_model.py index 9be2acce1a7b4aad7e292849034cca8cd10a60b0..831546453062e734b0e708dd33245e5fd13304e2 100644 --- a/nifty5/library/inverse_gamma_model.py +++ b/nifty5/library/inverse_gamma_model.py @@ -53,7 +53,7 @@ class InverseGammaModel(Operator): outer = 1/outer_inv jac = makeOp(Field.from_local_data(self._domain, inner*outer)) jac = jac(x.jac) - return Linearization(points, jac) + return x.new(points, jac) @staticmethod def IG(field, alpha, q): diff --git a/nifty5/linearization.py b/nifty5/linearization.py index 41210bdeec24e31147e72080491f16699520aa01..c27dec1abc65794ab57288480f49e59b95053d12 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -9,13 +9,17 @@ from .sugar import makeOp class Linearization(object): - def __init__(self, val, jac, metric=None): + def __init__(self, val, jac, metric=None, want_metric=False): self._val = val self._jac = jac if self._val.domain != self._jac.target: raise ValueError("domain mismatch") + self._want_metric = want_metric self._metric = metric + def new(self, val, jac, metric=None): + return Linearization(val, jac, metric, self._want_metric) + @property def domain(self): return self._jac.domain @@ -37,6 +41,10 @@ class Linearization(object): """Only available if target is a scalar""" return self._jac.adjoint_times(Field.scalar(1.)) + @property + def want_metric(self): + return self._want_metric + @property def metric(self): """Only available if target is a scalar""" @@ -44,35 +52,34 @@ class Linearization(object): def __getitem__(self, name): from .operators.simple_linear_operators import FieldAdapter - return Linearization(self._val[name], FieldAdapter(self.domain, name)) + return self.new(self._val[name], FieldAdapter(self.domain, name)) def __neg__(self): - return Linearization( - -self._val, -self._jac, - None if self._metric is None else -self._metric) + return self.new(-self._val, -self._jac, + None if self._metric is None else -self._metric) def conjugate(self): - return Linearization( + return self.new( self._val.conjugate(), self._jac.conjugate(), None if self._metric is None else self._metric.conjugate()) @property def real(self): - return Linearization(self._val.real, self._jac.real) + return self.new(self._val.real, self._jac.real) def _myadd(self, other, neg): if isinstance(other, Linearization): met = None if self._metric is not None and other._metric is not None: met = self._metric._myadd(other._metric, neg) - return Linearization( + return self.new( self._val.flexible_addsub(other._val, neg), self._jac._myadd(other._jac, neg), met) if isinstance(other, (int, float, complex, Field, MultiField)): if neg: - return Linearization(self._val-other, self._jac, self._metric) + return self.new(self._val-other, self._jac, self._metric) else: - return Linearization(self._val+other, self._jac, self._metric) + return self.new(self._val+other, self._jac, self._metric) def __add__(self, other): return self._myadd(other, False) @@ -91,7 +98,7 @@ class Linearization(object): if isinstance(other, Linearization): if self.target != other.target: raise ValueError("domain mismatch") - return Linearization( + return self.new( self._val*other._val, (makeOp(other._val)(self._jac))._myadd( makeOp(self._val)(other._jac), False)) @@ -99,11 +106,11 @@ class Linearization(object): if other == 1: return self met = None if self._metric is None else self._metric.scale(other) - return Linearization(self._val*other, self._jac.scale(other), met) + return self.new(self._val*other, self._jac.scale(other), met) if isinstance(other, (Field, MultiField)): if self.target != other.domain: raise ValueError("domain mismatch") - return Linearization(self._val*other, makeOp(other)(self._jac)) + return self.new(self._val*other, makeOp(other)(self._jac)) def __rmul__(self, other): return self.__mul__(other) @@ -111,46 +118,48 @@ class Linearization(object): def vdot(self, other): from .operators.simple_linear_operators import VdotOperator if isinstance(other, (Field, MultiField)): - return Linearization( + return self.new( Field.scalar(self._val.vdot(other)), VdotOperator(other)(self._jac)) - return Linearization( + return self.new( Field.scalar(self._val.vdot(other._val)), VdotOperator(self._val)(other._jac) + VdotOperator(other._val)(self._jac)) def sum(self): from .operators.simple_linear_operators import SumReductionOperator - return Linearization( + return self.new( Field.scalar(self._val.sum()), SumReductionOperator(self._jac.target)(self._jac)) def exp(self): tmp = self._val.exp() - return Linearization(tmp, makeOp(tmp)(self._jac)) + return self.new(tmp, makeOp(tmp)(self._jac)) def log(self): tmp = self._val.log() - return Linearization(tmp, makeOp(1./self._val)(self._jac)) + return self.new(tmp, makeOp(1./self._val)(self._jac)) def tanh(self): tmp = self._val.tanh() - return Linearization(tmp, makeOp(1.-tmp**2)(self._jac)) + return self.new(tmp, makeOp(1.-tmp**2)(self._jac)) def positive_tanh(self): tmp = self._val.tanh() tmp2 = 0.5*(1.+tmp) - return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) + return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) def add_metric(self, metric): - return Linearization(self._val, self._jac, metric) + return self.new(self._val, self._jac, metric) @staticmethod - def make_var(field): + def make_var(field, want_metric=False): from .operators.scaling_operator import ScalingOperator - return Linearization(field, ScalingOperator(1., field.domain)) + return Linearization(field, ScalingOperator(1., field.domain), + want_metric=want_metric) @staticmethod - def make_const(field): + def make_const(field, want_metric=False): from .operators.simple_linear_operators import NullOperator - return Linearization(field, NullOperator(field.domain, field.domain)) + return Linearization(field, NullOperator(field.domain, field.domain), + want_metric=want_metric) diff --git a/nifty5/minimization/energy_adapter.py b/nifty5/minimization/energy_adapter.py index 7d5c0e29684812cb8ce203122b301c038a8d0621..d0e71a96fd4c011d35fb4433bb9b8afe0824c2cd 100644 --- a/nifty5/minimization/energy_adapter.py +++ b/nifty5/minimization/energy_adapter.py @@ -8,17 +8,18 @@ from ..operators.scaling_operator import ScalingOperator class EnergyAdapter(Energy): - def __init__(self, position, op, constants=[]): + def __init__(self, position, op, constants=[], want_metric=False): super(EnergyAdapter, self).__init__(position) self._op = op self._constants = constants if len(self._constants) == 0: - tmp = self._op(Linearization.make_var(self._position)) + tmp = self._op(Linearization.make_var(self._position, want_metric)) else: ops = [ScalingOperator(0. if key in self._constants else 1., dom) for key, dom in self._position.domain.items()] bdop = BlockDiagonalOperator(self._position.domain, tuple(ops)) - tmp = self._op(Linearization(self._position, bdop)) + tmp = self._op(Linearization(self._position, bdop, + want_metric=want_metric)) self._val = tmp.val.local_data[()] self._grad = tmp.gradient self._metric = tmp._metric diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py index d21dbdafe6ca3358fc064b12986df6d0abe406f2..b00e7a46c47dec305f670f192094430a23688218 100644 --- a/nifty5/minimization/kl_energy.py +++ b/nifty5/minimization/kl_energy.py @@ -9,22 +9,24 @@ from .. import utilities class KL_Energy(Energy): - def __init__(self, position, h, nsamp, constants=[], _samples=None): + def __init__(self, position, h, nsamp, constants=[], _samples=None, + want_metric=False): super(KL_Energy, self).__init__(position) self._h = h self._constants = constants + self._want_metric = want_metric if _samples is None: - met = h(Linearization.make_var(position)).metric + met = h(Linearization.make_var(position, True)).metric _samples = tuple(met.draw_sample(from_inverse=True) for _ in range(nsamp)) self._samples = _samples if len(constants) == 0: - tmp = Linearization.make_var(position) + tmp = Linearization.make_var(position, want_metric) else: ops = [ScalingOperator(0. if key in constants else 1., dom) for key, dom in position.domain.items()] bdop = BlockDiagonalOperator(position.domain, tuple(ops)) - tmp = Linearization(position, bdop) + tmp = Linearization(position, bdop, want_metric=want_metric) mymap = map(lambda v: self._h(tmp+v), self._samples) tmp = utilities.my_sum(mymap) * (1./len(self._samples)) self._val = tmp.val.local_data[()] @@ -32,7 +34,8 @@ class KL_Energy(Energy): self._metric = tmp.metric def at(self, position): - return KL_Energy(position, self._h, 0, self._constants, self._samples) + return KL_Energy(position, self._h, 0, self._constants, self._samples, + self._want_metric) @property def value(self): diff --git a/nifty5/operators/energy_operators.py b/nifty5/operators/energy_operators.py index 7d502421e295b5f3f79697b9cf187ac202e56440..a68ea92527a4850bd79976cb25eab87d2e8d9d7c 100644 --- a/nifty5/operators/energy_operators.py +++ b/nifty5/operators/energy_operators.py @@ -42,7 +42,7 @@ class SquaredNormOperator(EnergyOperator): if isinstance(x, Linearization): val = Field.scalar(x.val.vdot(x.val)) jac = VdotOperator(2*x.val)(x.jac) - return Linearization(val, jac) + return x.new(val, jac) return Field.scalar(x.vdot(x)) @@ -59,7 +59,7 @@ class QuadraticFormOperator(EnergyOperator): t1 = self._op(x.val) jac = VdotOperator(t1)(x.jac) val = Field.scalar(0.5*x.val.vdot(t1)) - return Linearization(val, jac) + return x.new(val, jac) return Field.scalar(0.5*x.vdot(self._op(x))) @@ -91,7 +91,7 @@ class GaussianEnergy(EnergyOperator): def apply(self, x): residual = x if self._mean is None else x-self._mean res = self._op(residual).real - if not isinstance(x, Linearization): + if not isinstance(x, Linearization) or not x.want_metric: return res metric = SandwichOperator.make(x.jac, self._icov) return res.add_metric(metric) @@ -107,6 +107,8 @@ class PoissonianEnergy(EnergyOperator): res = x.sum() - x.log().vdot(self._d) if not isinstance(x, Linearization): return Field.scalar(res) + if not x.want_metric: + return res metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) return res.add_metric(metric) @@ -122,6 +124,8 @@ class BernoulliEnergy(EnergyOperator): v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) if not isinstance(x, Linearization): return Field.scalar(v) + if not x.want_metric: + return v met = makeOp(1./(x.val*(1.-x.val))) met = SandwichOperator.make(x.jac, met) return v.add_metric(met) @@ -135,11 +139,11 @@ class Hamiltonian(EnergyOperator): self._domain = lh.domain def apply(self, x): - if self._ic_samp is None or not isinstance(x, Linearization): + if (self._ic_samp is None or not isinstance(x, Linearization) or + not x.want_metric): return self._lh(x)+self._prior(x) else: - lhx = self._lh(x) - prx = self._prior(x) + lhx, prx = self._lh(x), self._prior(x) mtr = SamplingEnabler(lhx.metric, prx.metric.inverse, self._ic_samp, prx.metric.inverse) return (lhx+prx).add_metric(mtr) diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index 26a0cf81dc0dfac484ce06f2ba567158198c09c6..aa330c2f4ccede83ad73de5b72810da71ed542ba 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -175,7 +175,7 @@ class LinearOperator(Operator): return self.apply(x, self.TIMES) from ..linearization import Linearization if isinstance(x, Linearization): - return Linearization(self(x._val), self(x._jac)) + return x.new(self(x._val), self(x._jac)) return self.__matmul__(x) def times(self, x): diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 73ece15e64ef3e1f64b8d8e27f75e09f4ced4f89..13b7e6fb7bfb1e27f12e3f292744b239f868b8b5 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -144,11 +144,12 @@ class _OpProd(Operator): v2 = v.extract(self._op2.domain) if not lin: return self._op1(v1) * self._op2(v2) - lin1 = self._op1(Linearization.make_var(v1)) - lin2 = self._op2(Linearization.make_var(v2)) + wm = x.want_metric + lin1 = self._op1(Linearization.make_var(v1, wm)) + lin2 = self._op2(Linearization.make_var(v2, wm)) op = (makeOp(lin1._val)(lin2._jac))._myadd( makeOp(lin2._val)(lin1._jac), False) - return Linearization(lin1._val*lin2._val, op(x.jac)) + return lin1.new(lin1._val*lin2._val, op(x.jac)) class _OpSum(Operator): @@ -168,10 +169,11 @@ class _OpSum(Operator): res = None if not lin: return self._op1(v1).unite(self._op2(v2)) - lin1 = self._op1(Linearization.make_var(v1)) - lin2 = self._op2(Linearization.make_var(v2)) + wm = x.want_metric + lin1 = self._op1(Linearization.make_var(v1, wm)) + lin2 = self._op2(Linearization.make_var(v2, wm)) op = lin1._jac._myadd(lin2._jac, False) - res = Linearization(lin1._val+lin2._val, op(x.jac)) + res = lin1.new(lin1._val+lin2._val, op(x.jac)) if lin1._metric is not None and lin2._metric is not None: res = res.add_metric(lin1._metric + lin2._metric) return res