Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
9e71ab4c
Commit
9e71ab4c
authored
Jun 02, 2021
by
Philipp Frank
Browse files
docstrings
parent
2cf46dc5
Pipeline
#102770
passed with stages
in 13 minutes and 46 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Sidebyside
src/minimization/energy_adapter.py
View file @
9e71ab4c
...
...
@@ 23,6 +23,8 @@ from ..minimization.energy import Energy
from
..utilities
import
myassert
,
allreduce_sum
from
..multi_domain
import
MultiDomain
from
..sugar
import
from_random
from
..domain_tuple
import
DomainTuple
class
EnergyAdapter
(
Energy
):
"""Helper class which provides the traditional Nifty Energy interface to
...
...
@@ 90,28 +92,20 @@ class EnergyAdapter(Energy):
class
StochasticEnergyAdapter
(
Energy
):
"""A variant of `EnergyAdapter` that provides the energy interface for an
operator with a scalar target where parts of the imput are averaged
instead of optmized. Specifically, for the input corresponding to `keys`
a set of standart normal distributed samples are drawn and each gets
partially inserted into `bigop`. The results are averaged and represent a
stochastic average of an energy with the remaining subdomain being the DOFs
that are considered to be optimization parameters.
"""
def
__init__
(
self
,
position
,
bigop
,
keys
,
local_ops
,
n_samples
,
comm
,
nanisinf
,
def
__init__
(
self
,
position
,
op
,
keys
,
local_ops
,
n_samples
,
comm
,
nanisinf
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
super
(
StochasticEnergyAdapter
,
self
).
__init__
(
position
)
for
op
in
local_ops
:
myassert
(
position
.
domain
==
op
.
domain
)
for
l
op
in
local_ops
:
myassert
(
position
.
domain
==
l
op
.
domain
)
self
.
_comm
=
comm
self
.
_local_ops
=
local_ops
self
.
_n_samples
=
n_samples
lin
=
Linearization
.
make_var
(
position
)
v
,
g
=
[],
[]
for
op
in
self
.
_local_ops
:
tmp
=
op
(
lin
)
for
l
op
in
self
.
_local_ops
:
tmp
=
l
op
(
lin
)
v
.
append
(
tmp
.
val
.
val
)
g
.
append
(
tmp
.
gradient
)
self
.
_val
=
allreduce_sum
(
v
,
self
.
_comm
)[()]
/
self
.
_n_samples
...
...
@@ 119,7 +113,7 @@ class StochasticEnergyAdapter(Energy):
self
.
_val
=
np
.
inf
self
.
_grad
=
allreduce_sum
(
g
,
self
.
_comm
)
/
self
.
_n_samples
self
.
_op
=
big
op
self
.
_op
=
op
self
.
_keys
=
keys
@
property
...
...
@@ 131,8 +125,9 @@ class StochasticEnergyAdapter(Energy):
return
self
.
_grad
def
at
(
self
,
position
):
return
StochasticEnergyAdapter
(
position
,
self
.
_local_ops
,
self
.
_n_samples
,
self
.
_comm
,
self
.
_nanisinf
)
return
StochasticEnergyAdapter
(
position
,
self
.
_op
,
self
.
_keys
,
self
.
_local_ops
,
self
.
_n_samples
,
self
.
_comm
,
self
.
_nanisinf
,
_callingfrommake
=
True
)
def
apply_metric
(
self
,
x
):
lin
=
Linearization
.
make_var
(
self
.
position
,
want_metric
=
True
)
...
...
@@ 149,20 +144,56 @@ class StochasticEnergyAdapter(Energy):
def
resample_at
(
self
,
position
):
return
StochasticEnergyAdapter
.
make
(
position
,
self
.
_op
,
self
.
_keys
,
self
.
_n_samples
,
self
.
_comm
)
self
.
_n_samples
,
self
.
_comm
)
@
staticmethod
def
make
(
position
,
op
,
keys
,
n_samples
,
mirror_samples
,
nanisinf
=
False
,
comm
=
None
):
"""Energy adapter where parts of the model are sampled.
def
make
(
position
,
op
,
sampling_keys
,
n_samples
,
mirror_samples
,
comm
=
None
,
nanisinf
=
False
):
"""A variant of `EnergyAdapter` that provides the energy interface for an
operator with a scalar target where parts of the imput are averaged
instead of optmized.
Specifically, a set of standart normal distributed
samples are drawn for the input corresponding to `keys` and each sample
gets partially inserted into `op`. The resulting operators are averaged and
represent a stochastic average of an energy with the remaining subdomain
being the DOFs that are considered to be optimization parameters.
Parameters

position : MultiField
Values of the optimization parameters
op : Operator
The objective function of the optimization problem. Must have a
scalar target. The domain must be a `MultiDomain` with its keys
being the union of `sampling_keys` and `position.domain.keys()`.
sampling_keys : iterable of String
The keys of the subdomain over which the stochastic average of `op`
should be performed.
n_samples : int
Number of samples used for the stochastic estimate.
mirror_samples : boolean
Whether the negative of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples
doubles.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample
and its mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the
forward model are interpreted as inf. Thereby, the code does not
crash on these occasions but rather the minimizer is told that the
position it has tried is not sensible.
"""
myassert
(
op
.
target
==
DomainTuple
.
scalar_domain
())
samdom
=
{}
for
k
in
keys
:
if
k
in
position
.
domain
.
keys
():
raise
ValueError
if
k
not
in
op
.
domain
.
keys
():
if
not
isinstance
(
n_samples
,
int
)
:
raise
TypeError
for
k
in
sampling_keys
:
if
(
k
in
position
.
domain
.
keys
())
or
(
k
not
in
op
.
domain
.
keys
()
)
:
raise
ValueError
else
:
samdom
[
k
]
=
op
.
domain
[
k
]
samdom
[
k
]
=
op
.
domain
[
k
]
samdom
=
MultiDomain
.
make
(
samdom
)
local_ops
=
[]
sseq
=
random
.
spawn_sseq
(
n_samples
)
...
...
@@ 176,5 +207,5 @@ class StochasticEnergyAdapter(Energy):
if
mirror_samples
:
local_ops
.
append
(
op
.
simplify_for_constant_input
(

rnd
)[
1
])
n_samples
=
2
*
n_samples
if
mirror_samples
else
n_samples
return
StochasticEnergyAdapter
(
position
,
op
,
keys
,
local_ops
,
n_samples
,
comm
,
nanisinf
,
_callingfrommake
=
True
)
return
StochasticEnergyAdapter
(
position
,
op
,
sampling_
keys
,
local_ops
,
n_samples
,
comm
,
nanisinf
,
_callingfrommake
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment