diff --git a/nifty5/energies/energy_adapter.py b/nifty5/energies/energy_adapter.py index 61256371673ff1078e9f93b035fe34564327b17f..0e09bd9528e38b16cfd3ae2e400a6796bb139281 100644 --- a/nifty5/energies/energy_adapter.py +++ b/nifty5/energies/energy_adapter.py @@ -9,20 +9,31 @@ class EnergyAdapter(Energy): def __init__(self, position, op): super(EnergyAdapter, self).__init__(position) self._op = op - pvar = Linearization.make_var(position) - self._res = op(pvar) + self._val = self._grad = self._metric = None def at(self, position): return EnergyAdapter(position, self._op) + def _fill_all(self): + tmp = self._op(Linearization.make_var(self._position)) + self._val = tmp.val.local_data[()] + self._grad = tmp.gradient + self._metric = tmp.metric + @property def value(self): - return self._res.val.local_data[()] + if self._val is None: + self._val = self._op(self._position) + return self._val @property def gradient(self): - return self._res.gradient + if self._grad is None: + self._fill_all() + return self._grad @property def metric(self): - return self._res.metric + if self._metric is None: + self._fill_all() + return self._metric diff --git a/nifty5/energies/hamiltonian.py b/nifty5/energies/hamiltonian.py index d5cb473828c09d45621472285cf7ab73c9396d8c..48f4729ab768507c9246eaddbbb621b5d987d382 100644 --- a/nifty5/energies/hamiltonian.py +++ b/nifty5/energies/hamiltonian.py @@ -22,6 +22,7 @@ from ..compat import * from ..operators.operator import Operator from ..library.gaussian_energy import GaussianEnergy from ..operators.sampling_enabler import SamplingEnabler +from ..linearization import Linearization class Hamiltonian(Operator): @@ -40,8 +41,7 @@ class Hamiltonian(Operator): return DomainTuple.scalar_domain() def __call__(self, x): - res = self._lh(x) + self._prior(x) - if self._ic_samp is None: + if self._ic_samp is None or not isinstance(x, Linearization): return self._lh(x) + self._prior(x) else: lhx = self._lh(x) diff --git a/nifty5/library/bernoulli_energy.py b/nifty5/library/bernoulli_energy.py index 78041592df677efb24d0c488a8f417ec2e4cc139..f56f47428369fc5915b0e95ce67be46d6a0b9e58 100644 --- a/nifty5/library/bernoulli_energy.py +++ b/nifty5/library/bernoulli_energy.py @@ -22,6 +22,7 @@ from ..compat import * from ..operators.operator import Operator from ..operators.sandwich_operator import SandwichOperator from ..sugar import makeOp +from ..linearization import Linearization class BernoulliEnergy(Operator): @@ -41,6 +42,8 @@ class BernoulliEnergy(Operator): def __call__(self, x): x = self._p(x) v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum() + if not isinstance(x, Linearization): + return v met = makeOp(1./(x.val*(1.-x.val))) met = SandwichOperator.make(x.jac, met) return v.add_metric(met) diff --git a/nifty5/library/gaussian_energy.py b/nifty5/library/gaussian_energy.py index fbb9cf72457d1af997964ec1909ee35d5addabbd..74c9f68f2fcc205562aec823bce95d4499087dac 100644 --- a/nifty5/library/gaussian_energy.py +++ b/nifty5/library/gaussian_energy.py @@ -22,6 +22,7 @@ from ..compat import * from ..operators.operator import Operator from ..operators.sandwich_operator import SandwichOperator from ..domain_tuple import DomainTuple +from ..linearization import Linearization class GaussianEnergy(Operator): @@ -58,5 +59,7 @@ class GaussianEnergy(Operator): residual = x if self._mean is None else x-self._mean icovres = residual if self._icov is None else self._icov(residual) res = .5*(residual*icovres).sum() + if not isinstance(x, Linearization): + return res metric = SandwichOperator.make(x.jac, self._icov) return res.add_metric(metric) diff --git a/nifty5/library/poissonian_energy.py b/nifty5/library/poissonian_energy.py index 14410dc5a4414ce455a8dc6b2bcb262fb4a1cd12..94ba0c54235d3f94d9ee16d46f2d3132e3b61e48 100644 --- a/nifty5/library/poissonian_energy.py +++ b/nifty5/library/poissonian_energy.py @@ -24,6 +24,7 @@ from ..compat import * from ..operators.operator import Operator from ..operators.sandwich_operator import SandwichOperator from ..sugar import makeOp +from ..linearization import Linearization class PoissonianEnergy(Operator): @@ -43,5 +44,7 @@ class PoissonianEnergy(Operator): def __call__(self, x): x = self._op(x) res = (x - self._d*x.log()).sum() + if not isinstance(x, Linearization): + return res metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) return res.add_metric(metric)