Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
537234a4
Commit
537234a4
authored
Mar 09, 2020
by
Philipp Arras
Browse files
Performance fixups 3/n
parent
a29abca8
Pipeline
#70458
failed with stages
in 9 minutes and 29 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/energy_operators.py
View file @
537234a4
...
...
@@ -355,7 +355,10 @@ class BernoulliEnergy(EnergyOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
v
=
-
(
x
.
log
().
vdot
(
self
.
_d
)
+
(
1.
-
x
).
log
().
vdot
(
1.
-
self
.
_d
))
iden
=
FieldAdapter
(
self
.
_domain
,
'foo'
)
from
.adder
import
Adder
v
=
-
iden
.
log
().
vdot
(
self
.
_d
)
+
(
Adder
(
Field
.
full
(
self
.
_domain
,
1.
))
@
iden
.
scale
(
-
1
)).
log
().
vdot
(
self
.
_d
-
1.
)
v
=
v
(
iden
.
adjoint
(
x
))
if
not
isinstance
(
x
,
Linearization
):
return
Field
.
scalar
(
v
)
if
not
x
.
want_metric
:
...
...
@@ -455,5 +458,11 @@ class AveragedEnergy(EnergyOperator):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
)
return
utilities
.
my_sum
(
mymap
)
*
(
1.
/
len
(
self
.
_res_samples
))
if
isinstance
(
self
.
_domain
,
MultiDomain
):
iden
=
ScalingOperator
(
self
.
_domain
,
1.
)
else
:
iden
=
FieldAdapter
(
self
.
_domain
,
'foo'
)
x
=
iden
.
adjoint
(
x
)
from
.adder
import
Adder
mymap
=
map
(
lambda
v
:
self
.
_h
(
Adder
(
v
)
@
iden
),
self
.
_res_samples
)
return
utilities
.
my_sum
(
mymap
).
scale
(
1.
/
len
(
self
.
_res_samples
))(
x
)
test/test_energy_gradients.py
View file @
537234a4
...
...
@@ -25,14 +25,13 @@ from itertools import product
# hopefully be fixed in the future.
# https://docs.pytest.org/en/latest/proposals/parametrize_with_fixtures.html
SPACES
=
[
ift
.
GLSpace
(
15
),
ift
.
RGSpace
(
64
,
distances
=
.
789
),
ift
.
RGSpace
([
32
,
32
],
distances
=
.
789
)
]
SPACES
=
[
ift
.
GLSpace
(
15
),
ift
.
RGSpace
(
64
,
distances
=
.
789
),
ift
.
RGSpace
([
32
,
32
],
distances
=
.
789
)]
SEEDS
=
[
4
,
78
,
23
]
PARAMS
=
product
(
SEEDS
,
SPACES
)
pmp
=
pytest
.
mark
.
parametrize
# FIXME Test also with multifields in domain
@
pytest
.
fixture
(
params
=
PARAMS
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment