Commit 62d56279 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

peformance tweaks

parent e8079ae3
......@@ -41,7 +41,7 @@ class BernoulliEnergy(Operator):
def __call__(self, x):
x = self._p(x)
v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum()
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization):
return v
met = makeOp(1./(x.val*(1.-x.val)))
......
......@@ -58,7 +58,7 @@ class GaussianEnergy(Operator):
def __call__(self, x):
residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual)
res = .5*(residual*icovres).sum()
res = .5*residual.vdot(icovres)
if not isinstance(x, Linearization):
return res
metric = SandwichOperator.make(x.jac, self._icov)
......
......@@ -43,7 +43,7 @@ class PoissonianEnergy(Operator):
def __call__(self, x):
x = self._op(x)
res = (x - self._d*x.log()).sum()
res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization):
return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
......
......@@ -33,8 +33,7 @@ class Linearization(object):
@property
def gradient(self):
"""Only available if target is a scalar"""
from .sugar import full
return self._jac.adjoint_times(full(self._jac.target, 1.))
return self._jac.adjoint_times(Field(self._jac.target, 1.))
@property
def metric(self):
......@@ -73,14 +72,12 @@ class Linearization(object):
def __mul__(self, other):
from .sugar import makeOp
from .operators.relaxed_sum_operator import RelaxedSumOperator
if isinstance(other, Linearization):
d1 = makeOp(self._val)
d2 = makeOp(other._val)
return Linearization(
self._val*other._val,
RelaxedSumOperator((d2.chain(self._jac),
d1.chain(other._jac))))
d2.chain(self._jac) + d1.chain(other._jac))
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
......@@ -99,12 +96,25 @@ class Linearization(object):
d1 = makeOp(other)
return Linearization(self._val*other, d1.chain(self._jac))
def vdot(self, other):
from .domain_tuple import DomainTuple
from .operators.vdot_operator import VdotOperator
if isinstance(other, (Field, MultiField)):
return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other)),
VdotOperator(other).chain(self._jac))
return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)),
VdotOperator(self._val).chain(other._jac) +
VdotOperator(other._val).chain(self._jac))
def sum(self):
from .domain_tuple import DomainTuple
from .operators.vdot_operator import SumReductionOperator
from .sugar import full
from .operators.vdot_operator import VdotOperator
return Linearization(
full((), self._val.sum()),
VdotOperator(full(self._jac.target, 1)).chain(self._jac))
Field(DomainTuple.scalar_domain(), self._val.sum()),
SumReductionOperator(self._jac.target).chain(self._jac))
def exp(self):
tmp = self._val.exp()
......
......@@ -25,13 +25,14 @@ from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from .linear_operator import LinearOperator
from ..sugar import full
from ..field import Field
class VdotOperator(LinearOperator):
def __init__(self, field):
super(VdotOperator, self).__init__()
self._field = field
self._target = DomainTuple.make(())
self._target = DomainTuple.scalar_domain()
@property
def domain(self):
......@@ -48,5 +49,30 @@ class VdotOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return full(self._target, self._field.vdot(x))
return Field(self._target, self._field.vdot(x))
return self._field*x.local_data[()]
class SumReductionOperator(LinearOperator):
def __init__(self, domain):
super(SumReductionOperator, self).__init__()
self._domain = domain
self._target = DomainTuple.scalar_domain()
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field(self._target, x.sum())
return full(self._domain, x.local_data[()])
......@@ -73,4 +73,4 @@ class Energy_Tests(unittest.TestCase):
energy, xi0, ntries=10)
else:
ift.extra.check_value_gradient_consistency(
energy, xi0, ntries=10)
energy, xi0, ntries=10, tol=5e-8)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment