Commit 9e71ab4c authored by Philipp Frank's avatar Philipp Frank
Browse files

docstrings

parent 2cf46dc5
Pipeline #102770 passed with stages
in 13 minutes and 46 seconds
...@@ -23,6 +23,8 @@ from ..minimization.energy import Energy ...@@ -23,6 +23,8 @@ from ..minimization.energy import Energy
from ..utilities import myassert, allreduce_sum from ..utilities import myassert, allreduce_sum
from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
from ..sugar import from_random from ..sugar import from_random
from ..domain_tuple import DomainTuple
class EnergyAdapter(Energy): class EnergyAdapter(Energy):
"""Helper class which provides the traditional Nifty Energy interface to """Helper class which provides the traditional Nifty Energy interface to
...@@ -90,28 +92,20 @@ class EnergyAdapter(Energy): ...@@ -90,28 +92,20 @@ class EnergyAdapter(Energy):
class StochasticEnergyAdapter(Energy): class StochasticEnergyAdapter(Energy):
"""A variant of `EnergyAdapter` that provides the energy interface for an def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
operator with a scalar target where parts of the imput are averaged
instead of optmized. Specifically, for the input corresponding to `keys`
a set of standart normal distributed samples are drawn and each gets
partially inserted into `bigop`. The results are averaged and represent a
stochastic average of an energy with the remaining subdomain being the DOFs
that are considered to be optimization parameters.
"""
def __init__(self, position, bigop, keys, local_ops, n_samples, comm, nanisinf,
_callingfrommake=False): _callingfrommake=False):
if not _callingfrommake: if not _callingfrommake:
raise NotImplementedError raise NotImplementedError
super(StochasticEnergyAdapter, self).__init__(position) super(StochasticEnergyAdapter, self).__init__(position)
for op in local_ops: for lop in local_ops:
myassert(position.domain == op.domain) myassert(position.domain == lop.domain)
self._comm = comm self._comm = comm
self._local_ops = local_ops self._local_ops = local_ops
self._n_samples = n_samples self._n_samples = n_samples
lin = Linearization.make_var(position) lin = Linearization.make_var(position)
v, g = [], [] v, g = [], []
for op in self._local_ops: for lop in self._local_ops:
tmp = op(lin) tmp = lop(lin)
v.append(tmp.val.val) v.append(tmp.val.val)
g.append(tmp.gradient) g.append(tmp.gradient)
self._val = allreduce_sum(v, self._comm)[()]/self._n_samples self._val = allreduce_sum(v, self._comm)[()]/self._n_samples
...@@ -119,7 +113,7 @@ class StochasticEnergyAdapter(Energy): ...@@ -119,7 +113,7 @@ class StochasticEnergyAdapter(Energy):
self._val = np.inf self._val = np.inf
self._grad = allreduce_sum(g, self._comm)/self._n_samples self._grad = allreduce_sum(g, self._comm)/self._n_samples
self._op = bigop self._op = op
self._keys = keys self._keys = keys
@property @property
...@@ -131,8 +125,9 @@ class StochasticEnergyAdapter(Energy): ...@@ -131,8 +125,9 @@ class StochasticEnergyAdapter(Energy):
return self._grad return self._grad
def at(self, position): def at(self, position):
return StochasticEnergyAdapter(position, self._local_ops, return StochasticEnergyAdapter(position, self._op, self._keys,
self._n_samples, self._comm, self._nanisinf) self._local_ops, self._n_samples, self._comm, self._nanisinf,
_callingfrommake=True)
def apply_metric(self, x): def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True) lin = Linearization.make_var(self.position, want_metric=True)
...@@ -149,20 +144,56 @@ class StochasticEnergyAdapter(Energy): ...@@ -149,20 +144,56 @@ class StochasticEnergyAdapter(Energy):
def resample_at(self, position): def resample_at(self, position):
return StochasticEnergyAdapter.make(position, self._op, self._keys, return StochasticEnergyAdapter.make(position, self._op, self._keys,
self._n_samples, self._comm) self._n_samples, self._comm)
@staticmethod @staticmethod
def make(position, op, keys, n_samples, mirror_samples, nanisinf = False, comm=None): def make(position, op, sampling_keys, n_samples, mirror_samples,
"""Energy adapter where parts of the model are sampled. comm=None, nanisinf = False):
"""A variant of `EnergyAdapter` that provides the energy interface for an
operator with a scalar target where parts of the imput are averaged
instead of optmized.
Specifically, a set of standart normal distributed
samples are drawn for the input corresponding to `keys` and each sample
gets partially inserted into `op`. The resulting operators are averaged and
represent a stochastic average of an energy with the remaining subdomain
being the DOFs that are considered to be optimization parameters.
Parameters
----------
position : MultiField
Values of the optimization parameters
op : Operator
The objective function of the optimization problem. Must have a
scalar target. The domain must be a `MultiDomain` with its keys
being the union of `sampling_keys` and `position.domain.keys()`.
sampling_keys : iterable of String
The keys of the subdomain over which the stochastic average of `op`
should be performed.
n_samples : int
Number of samples used for the stochastic estimate.
mirror_samples : boolean
Whether the negative of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples
doubles.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample
and its mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the
forward model are interpreted as inf. Thereby, the code does not
crash on these occasions but rather the minimizer is told that the
position it has tried is not sensible.
""" """
myassert(op.target == DomainTuple.scalar_domain())
samdom = {} samdom = {}
for k in keys: if not isinstance(n_samples, int):
if k in position.domain.keys(): raise TypeError
raise ValueError for k in sampling_keys:
if k not in op.domain.keys(): if (k in position.domain.keys()) or (k not in op.domain.keys()):
raise ValueError raise ValueError
else: samdom[k] = op.domain[k]
samdom[k] = op.domain[k]
samdom = MultiDomain.make(samdom) samdom = MultiDomain.make(samdom)
local_ops = [] local_ops = []
sseq = random.spawn_sseq(n_samples) sseq = random.spawn_sseq(n_samples)
...@@ -176,5 +207,5 @@ class StochasticEnergyAdapter(Energy): ...@@ -176,5 +207,5 @@ class StochasticEnergyAdapter(Energy):
if mirror_samples: if mirror_samples:
local_ops.append(op.simplify_for_constant_input(-rnd)[1]) local_ops.append(op.simplify_for_constant_input(-rnd)[1])
n_samples = 2*n_samples if mirror_samples else n_samples n_samples = 2*n_samples if mirror_samples else n_samples
return StochasticEnergyAdapter(position, op, keys, local_ops, n_samples, return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
comm, nanisinf, _callingfrommake=True) n_samples, comm, nanisinf, _callingfrommake=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment