Commit 62d56279 by Martin Reinecke

### peformance tweaks

parent e8079ae3
 ... ... @@ -41,7 +41,7 @@ class BernoulliEnergy(Operator): def __call__(self, x): x = self._p(x) v = ((-self._d)*x.log()).sum() - ((1.-self._d)*((1.-x).log())).sum() v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) if not isinstance(x, Linearization): return v met = makeOp(1./(x.val*(1.-x.val))) ... ...
 ... ... @@ -58,7 +58,7 @@ class GaussianEnergy(Operator): def __call__(self, x): residual = x if self._mean is None else x-self._mean icovres = residual if self._icov is None else self._icov(residual) res = .5*(residual*icovres).sum() res = .5*residual.vdot(icovres) if not isinstance(x, Linearization): return res metric = SandwichOperator.make(x.jac, self._icov) ... ...
 ... ... @@ -43,7 +43,7 @@ class PoissonianEnergy(Operator): def __call__(self, x): x = self._op(x) res = (x - self._d*x.log()).sum() res = x.sum() - x.log().vdot(self._d) if not isinstance(x, Linearization): return res metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) ... ...
 ... ... @@ -33,8 +33,7 @@ class Linearization(object): @property def gradient(self): """Only available if target is a scalar""" from .sugar import full return self._jac.adjoint_times(full(self._jac.target, 1.)) return self._jac.adjoint_times(Field(self._jac.target, 1.)) @property def metric(self): ... ... @@ -73,14 +72,12 @@ class Linearization(object): def __mul__(self, other): from .sugar import makeOp from .operators.relaxed_sum_operator import RelaxedSumOperator if isinstance(other, Linearization): d1 = makeOp(self._val) d2 = makeOp(other._val) return Linearization( self._val*other._val, RelaxedSumOperator((d2.chain(self._jac), d1.chain(other._jac)))) d2.chain(self._jac) + d1.chain(other._jac)) if isinstance(other, (int, float, complex)): # if other == 0: # return ... ... ... @@ -99,12 +96,25 @@ class Linearization(object): d1 = makeOp(other) return Linearization(self._val*other, d1.chain(self._jac)) def vdot(self, other): from .domain_tuple import DomainTuple from .operators.vdot_operator import VdotOperator if isinstance(other, (Field, MultiField)): return Linearization( Field(DomainTuple.scalar_domain(),self._val.vdot(other)), VdotOperator(other).chain(self._jac)) return Linearization( Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)), VdotOperator(self._val).chain(other._jac) + VdotOperator(other._val).chain(self._jac)) def sum(self): from .domain_tuple import DomainTuple from .operators.vdot_operator import SumReductionOperator from .sugar import full from .operators.vdot_operator import VdotOperator return Linearization( full((), self._val.sum()), VdotOperator(full(self._jac.target, 1)).chain(self._jac)) Field(DomainTuple.scalar_domain(), self._val.sum()), SumReductionOperator(self._jac.target).chain(self._jac)) def exp(self): tmp = self._val.exp() ... ...
 ... ... @@ -25,13 +25,14 @@ from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain from .linear_operator import LinearOperator from ..sugar import full from ..field import Field class VdotOperator(LinearOperator): def __init__(self, field): super(VdotOperator, self).__init__() self._field = field self._target = DomainTuple.make(()) self._target = DomainTuple.scalar_domain() @property def domain(self): ... ... @@ -48,5 +49,30 @@ class VdotOperator(LinearOperator): def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: return full(self._target, self._field.vdot(x)) return Field(self._target, self._field.vdot(x)) return self._field*x.local_data[()] class SumReductionOperator(LinearOperator): def __init__(self, domain): super(SumReductionOperator, self).__init__() self._domain = domain self._target = DomainTuple.scalar_domain() @property def domain(self): return self._domain @property def target(self): return self._target @property def capability(self): return self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: return Field(self._target, x.sum()) return full(self._domain, x.local_data[()])
 ... ... @@ -73,4 +73,4 @@ class Energy_Tests(unittest.TestCase): energy, xi0, ntries=10) else: ift.extra.check_value_gradient_consistency( energy, xi0, ntries=10) energy, xi0, ntries=10, tol=5e-8)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment