Commit a13d5e4a authored by Philipp Arras's avatar Philipp Arras
Browse files

Performance fixups 4/n

parent 537234a4
Pipeline #70460 passed with stages
in 15 minutes and 53 seconds
......@@ -15,9 +15,12 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..field import Field
from ..multi_field import MultiField
from .operator import Operator
from ..sugar import makeDomain
class Adder(Operator):
......@@ -25,18 +28,22 @@ class Adder(Operator):
Parameters
----------
field : Field or MultiField
a : Field or MultiField or Scalar
The field by which the input is shifted.
"""
def __init__(self, field, neg=False):
if not isinstance(field, (Field, MultiField)):
def __init__(self, a, neg=False, domain=None):
self._a = a
if isinstance(a, (Field, MultiField)):
dom = a.domain
elif np.isscalar(a):
dom = makeDomain(domain)
else:
raise TypeError
self._field = field
self._domain = self._target = field.domain
self._domain = self._target = dom
self._neg = bool(neg)
def apply(self, x):
self._check_input(x)
if self._neg:
return x - self._field
return x + self._field
return x - self._a
return x + self._a
......@@ -136,9 +136,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res0 = r.vdot(r*icov).real
res1 = icov.log().sum()
res = (res0-res1).scale(0.5)(x)
if not lin:
return Field.scalar(res)
if not x.want_metric:
if not lin or not x.want_metric:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
metric = makeOp(MultiField.from_dict(mf))
......@@ -242,9 +240,7 @@ class PoissonianEnergy(EnergyOperator):
self._check_input(x)
fa = FieldAdapter(self._domain, 'foo')
res = (fa.sum() - fa.log().vdot(self._d))(fa.adjoint(x))
if not isinstance(x, Linearization):
return Field.scalar(res)
if not x.want_metric:
if not isinstance(x, Linearization) or not x.want_metric:
return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric)
......@@ -287,9 +283,7 @@ class InverseGammaLikelihood(EnergyOperator):
fa = FieldAdapter(self._domain, 'foo')
x = fa.adjoint(x)
res = (fa.log().vdot(self._alphap1) + fa.one_over().vdot(self._beta))(x)
if not isinstance(x, Linearization):
return Field.scalar(res)
if not x.want_metric:
if not isinstance(x, Linearization) or not x.want_metric:
return res
metric = SandwichOperator.make(x.jac, makeOp(self._alphap1/(x.val**2)))
return res.add_metric(metric)
......@@ -357,11 +351,9 @@ class BernoulliEnergy(EnergyOperator):
self._check_input(x)
iden = FieldAdapter(self._domain, 'foo')
from .adder import Adder
v = -iden.log().vdot(self._d) + (Adder(Field.full(self._domain, 1.)) @ iden.scale(-1)).log().vdot(self._d-1.)
v = -iden.log().vdot(self._d) + (Adder(1, domain=self._domain) @ iden.scale(-1)).log().vdot(self._d-1.)
v = v(iden.adjoint(x))
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
if not isinstance(x, Linearization) or not x.want_metric:
return v
met = makeOp(1./(x.val*(1. - x.val)))
met = SandwichOperator.make(x.jac, met)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment