Commit 7a201726 authored by Martin Reinecke's avatar Martin Reinecke

performance tweaks

parent 08151150
......@@ -9,20 +9,31 @@ class EnergyAdapter(Energy):
def __init__(self, position, op):
super(EnergyAdapter, self).__init__(position)
self._op = op
pvar = Linearization.make_var(position)
self._res = op(pvar)
self._val = self._grad = self._metric = None
def at(self, position):
return EnergyAdapter(position, self._op)
def _fill_all(self):
tmp = self._op(Linearization.make_var(self._position))
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp.metric
@property
def value(self):
return self._res.val.local_data[()]
if self._val is None:
self._val = self._op(self._position)
return self._val
@property
def gradient(self):
return self._res.gradient
if self._grad is None:
self._fill_all()
return self._grad
@property
def metric(self):
return self._res.metric
if self._metric is None:
self._fill_all()
return self._metric
......@@ -22,6 +22,7 @@ from ..compat import *
from ..operators.operator import Operator
from ..library.gaussian_energy import GaussianEnergy
from ..operators.sampling_enabler import SamplingEnabler
from ..linearization import Linearization
class Hamiltonian(Operator):
......@@ -40,8 +41,7 @@ class Hamiltonian(Operator):
return DomainTuple.scalar_domain()
def __call__(self, x):
res = self._lh(x) + self._prior(x)
if self._ic_samp is None:
if self._ic_samp is None or not isinstance(x, Linearization):
return self._lh(x) + self._prior(x)
else:
lhx = self._lh(x)
......
......@@ -22,6 +22,7 @@ from ..compat import *
from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp
from ..linearization import Linearization
class BernoulliEnergy(Operator):
......@@ -41,6 +42,8 @@ class BernoulliEnergy(Operator):
def __call__(self, x):
x = self._p(x)
v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum()
if not isinstance(x, Linearization):
return v
met = makeOp(1./(x.val*(1.-x.val)))
met = SandwichOperator.make(x.jac, met)
return v.add_metric(met)
......@@ -22,6 +22,7 @@ from ..compat import *
from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator
from ..domain_tuple import DomainTuple
from ..linearization import Linearization
class GaussianEnergy(Operator):
......@@ -58,5 +59,7 @@ class GaussianEnergy(Operator):
residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual)
res = .5*(residual*icovres).sum()
if not isinstance(x, Linearization):
return res
metric = SandwichOperator.make(x.jac, self._icov)
return res.add_metric(metric)
......@@ -24,6 +24,7 @@ from ..compat import *
from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp
from ..linearization import Linearization
class PoissonianEnergy(Operator):
......@@ -43,5 +44,7 @@ class PoissonianEnergy(Operator):
def __call__(self, x):
x = self._op(x)
res = (x - self._d*x.log()).sum()
if not isinstance(x, Linearization):
return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment