Commit a13d5e4a authored by Philipp Arras's avatar Philipp Arras
Browse files

Performance fixups 4/n

parent 537234a4
Pipeline #70460 passed with stages
in 15 minutes and 53 seconds
...@@ -15,9 +15,12 @@ ...@@ -15,9 +15,12 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..field import Field from ..field import Field
from ..multi_field import MultiField from ..multi_field import MultiField
from .operator import Operator from .operator import Operator
from ..sugar import makeDomain
class Adder(Operator): class Adder(Operator):
...@@ -25,18 +28,22 @@ class Adder(Operator): ...@@ -25,18 +28,22 @@ class Adder(Operator):
Parameters Parameters
---------- ----------
field : Field or MultiField a : Field or MultiField or Scalar
The field by which the input is shifted. The field by which the input is shifted.
""" """
def __init__(self, field, neg=False): def __init__(self, a, neg=False, domain=None):
if not isinstance(field, (Field, MultiField)): self._a = a
if isinstance(a, (Field, MultiField)):
dom = a.domain
elif np.isscalar(a):
dom = makeDomain(domain)
else:
raise TypeError raise TypeError
self._field = field self._domain = self._target = dom
self._domain = self._target = field.domain
self._neg = bool(neg) self._neg = bool(neg)
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
if self._neg: if self._neg:
return x - self._field return x - self._a
return x + self._field return x + self._a
...@@ -136,9 +136,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -136,9 +136,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res0 = r.vdot(r*icov).real res0 = r.vdot(r*icov).real
res1 = icov.log().sum() res1 = icov.log().sum()
res = (res0-res1).scale(0.5)(x) res = (res0-res1).scale(0.5)(x)
if not lin: if not lin or not x.want_metric:
return Field.scalar(res)
if not x.want_metric:
return res return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)} mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
metric = makeOp(MultiField.from_dict(mf)) metric = makeOp(MultiField.from_dict(mf))
...@@ -242,9 +240,7 @@ class PoissonianEnergy(EnergyOperator): ...@@ -242,9 +240,7 @@ class PoissonianEnergy(EnergyOperator):
self._check_input(x) self._check_input(x)
fa = FieldAdapter(self._domain, 'foo') fa = FieldAdapter(self._domain, 'foo')
res = (fa.sum() - fa.log().vdot(self._d))(fa.adjoint(x)) res = (fa.sum() - fa.log().vdot(self._d))(fa.adjoint(x))
if not isinstance(x, Linearization): if not isinstance(x, Linearization) or not x.want_metric:
return Field.scalar(res)
if not x.want_metric:
return res return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric) return res.add_metric(metric)
...@@ -287,9 +283,7 @@ class InverseGammaLikelihood(EnergyOperator): ...@@ -287,9 +283,7 @@ class InverseGammaLikelihood(EnergyOperator):
fa = FieldAdapter(self._domain, 'foo') fa = FieldAdapter(self._domain, 'foo')
x = fa.adjoint(x) x = fa.adjoint(x)
res = (fa.log().vdot(self._alphap1) + fa.one_over().vdot(self._beta))(x) res = (fa.log().vdot(self._alphap1) + fa.one_over().vdot(self._beta))(x)
if not isinstance(x, Linearization): if not isinstance(x, Linearization) or not x.want_metric:
return Field.scalar(res)
if not x.want_metric:
return res return res
metric = SandwichOperator.make(x.jac, makeOp(self._alphap1/(x.val**2))) metric = SandwichOperator.make(x.jac, makeOp(self._alphap1/(x.val**2)))
return res.add_metric(metric) return res.add_metric(metric)
...@@ -357,11 +351,9 @@ class BernoulliEnergy(EnergyOperator): ...@@ -357,11 +351,9 @@ class BernoulliEnergy(EnergyOperator):
self._check_input(x) self._check_input(x)
iden = FieldAdapter(self._domain, 'foo') iden = FieldAdapter(self._domain, 'foo')
from .adder import Adder from .adder import Adder
v = -iden.log().vdot(self._d) + (Adder(Field.full(self._domain, 1.)) @ iden.scale(-1)).log().vdot(self._d-1.) v = -iden.log().vdot(self._d) + (Adder(1, domain=self._domain) @ iden.scale(-1)).log().vdot(self._d-1.)
v = v(iden.adjoint(x)) v = v(iden.adjoint(x))
if not isinstance(x, Linearization): if not isinstance(x, Linearization) or not x.want_metric:
return Field.scalar(v)
if not x.want_metric:
return v return v
met = makeOp(1./(x.val*(1. - x.val))) met = makeOp(1./(x.val*(1. - x.val)))
met = SandwichOperator.make(x.jac, met) met = SandwichOperator.make(x.jac, met)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment