Commit 537234a4 authored by Philipp Arras's avatar Philipp Arras
Browse files

Performance fixups 3/n

parent a29abca8
......@@ -355,7 +355,10 @@ class BernoulliEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
v = -(x.log().vdot(self._d) + (1. - x).log().vdot(1. - self._d))
iden = FieldAdapter(self._domain, 'foo')
from .adder import Adder
v = -iden.log().vdot(self._d) + (Adder(Field.full(self._domain, 1.)) @ iden.scale(-1)).log().vdot(self._d-1.)
v = v(iden.adjoint(x))
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
......@@ -455,5 +458,11 @@ class AveragedEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
mymap = map(lambda v: self._h(x + v), self._res_samples)
return utilities.my_sum(mymap)*(1./len(self._res_samples))
if isinstance(self._domain, MultiDomain):
iden = ScalingOperator(self._domain, 1.)
else:
iden = FieldAdapter(self._domain, 'foo')
x = iden.adjoint(x)
from .adder import Adder
mymap = map(lambda v: self._h(Adder(v) @ iden), self._res_samples)
return utilities.my_sum(mymap).scale(1./len(self._res_samples))(x)
......@@ -25,14 +25,13 @@ from itertools import product
# hopefully be fixed in the future.
# https://docs.pytest.org/en/latest/proposals/parametrize_with_fixtures.html
SPACES = [
ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)
]
SPACES = [ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)]
SEEDS = [4, 78, 23]
PARAMS = product(SEEDS, SPACES)
pmp = pytest.mark.parametrize
# FIXME Test also with multifields in domain
@pytest.fixture(params=PARAMS)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment