Commit 7a201726 authored by Martin Reinecke's avatar Martin Reinecke

performance tweaks

parent 08151150
...@@ -9,20 +9,31 @@ class EnergyAdapter(Energy): ...@@ -9,20 +9,31 @@ class EnergyAdapter(Energy):
def __init__(self, position, op): def __init__(self, position, op):
super(EnergyAdapter, self).__init__(position) super(EnergyAdapter, self).__init__(position)
self._op = op self._op = op
pvar = Linearization.make_var(position) self._val = self._grad = self._metric = None
self._res = op(pvar)
def at(self, position): def at(self, position):
return EnergyAdapter(position, self._op) return EnergyAdapter(position, self._op)
def _fill_all(self):
tmp = self._op(Linearization.make_var(self._position))
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp.metric
@property @property
def value(self): def value(self):
return self._res.val.local_data[()] if self._val is None:
self._val = self._op(self._position)
return self._val
@property @property
def gradient(self): def gradient(self):
return self._res.gradient if self._grad is None:
self._fill_all()
return self._grad
@property @property
def metric(self): def metric(self):
return self._res.metric if self._metric is None:
self._fill_all()
return self._metric
...@@ -22,6 +22,7 @@ from ..compat import * ...@@ -22,6 +22,7 @@ from ..compat import *
from ..operators.operator import Operator from ..operators.operator import Operator
from ..library.gaussian_energy import GaussianEnergy from ..library.gaussian_energy import GaussianEnergy
from ..operators.sampling_enabler import SamplingEnabler from ..operators.sampling_enabler import SamplingEnabler
from ..linearization import Linearization
class Hamiltonian(Operator): class Hamiltonian(Operator):
...@@ -40,8 +41,7 @@ class Hamiltonian(Operator): ...@@ -40,8 +41,7 @@ class Hamiltonian(Operator):
return DomainTuple.scalar_domain() return DomainTuple.scalar_domain()
def __call__(self, x): def __call__(self, x):
res = self._lh(x) + self._prior(x) if self._ic_samp is None or not isinstance(x, Linearization):
if self._ic_samp is None:
return self._lh(x) + self._prior(x) return self._lh(x) + self._prior(x)
else: else:
lhx = self._lh(x) lhx = self._lh(x)
......
...@@ -22,6 +22,7 @@ from ..compat import * ...@@ -22,6 +22,7 @@ from ..compat import *
from ..operators.operator import Operator from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp from ..sugar import makeOp
from ..linearization import Linearization
class BernoulliEnergy(Operator): class BernoulliEnergy(Operator):
...@@ -41,6 +42,8 @@ class BernoulliEnergy(Operator): ...@@ -41,6 +42,8 @@ class BernoulliEnergy(Operator):
def __call__(self, x): def __call__(self, x):
x = self._p(x) x = self._p(x)
v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum() v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum()
if not isinstance(x, Linearization):
return v
met = makeOp(1./(x.val*(1.-x.val))) met = makeOp(1./(x.val*(1.-x.val)))
met = SandwichOperator.make(x.jac, met) met = SandwichOperator.make(x.jac, met)
return v.add_metric(met) return v.add_metric(met)
...@@ -22,6 +22,7 @@ from ..compat import * ...@@ -22,6 +22,7 @@ from ..compat import *
from ..operators.operator import Operator from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..linearization import Linearization
class GaussianEnergy(Operator): class GaussianEnergy(Operator):
...@@ -58,5 +59,7 @@ class GaussianEnergy(Operator): ...@@ -58,5 +59,7 @@ class GaussianEnergy(Operator):
residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual) icovres = residual if self._icov is None else self._icov(residual)
res = .5*(residual*icovres).sum() res = .5*(residual*icovres).sum()
if not isinstance(x, Linearization):
return res
metric = SandwichOperator.make(x.jac, self._icov) metric = SandwichOperator.make(x.jac, self._icov)
return res.add_metric(metric) return res.add_metric(metric)
...@@ -24,6 +24,7 @@ from ..compat import * ...@@ -24,6 +24,7 @@ from ..compat import *
from ..operators.operator import Operator from ..operators.operator import Operator
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..sugar import makeOp from ..sugar import makeOp
from ..linearization import Linearization
class PoissonianEnergy(Operator): class PoissonianEnergy(Operator):
...@@ -43,5 +44,7 @@ class PoissonianEnergy(Operator): ...@@ -43,5 +44,7 @@ class PoissonianEnergy(Operator):
def __call__(self, x): def __call__(self, x):
x = self._op(x) x = self._op(x)
res = (x - self._d*x.log()).sum() res = (x - self._d*x.log()).sum()
if not isinstance(x, Linearization):
return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric) return res.add_metric(metric)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment