diff --git a/src/minimization/energy_adapter.py b/src/minimization/energy_adapter.py
index f26ee439fd6e3d908c53dd73baa9c9dbb7acb6ce..c9f486821346ee055e7f5d0331c1d2fb62286603 100644
--- a/src/minimization/energy_adapter.py
+++ b/src/minimization/energy_adapter.py
@@ -23,6 +23,8 @@ from ..minimization.energy import Energy
from ..utilities import myassert, allreduce_sum
from ..multi_domain import MultiDomain
from ..sugar import from_random
+from ..domain_tuple import DomainTuple
+
class EnergyAdapter(Energy):
"""Helper class which provides the traditional Nifty Energy interface to
@@ -90,28 +92,20 @@ class EnergyAdapter(Energy):
class StochasticEnergyAdapter(Energy):
- """A variant of `EnergyAdapter` that provides the energy interface for an
- operator with a scalar target where parts of the imput are averaged
- instead of optmized. Specifically, for the input corresponding to `keys`
- a set of standart normal distributed samples are drawn and each gets
- partially inserted into `bigop`. The results are averaged and represent a
- stochastic average of an energy with the remaining subdomain being the DOFs
- that are considered to be optimization parameters.
- """
- def __init__(self, position, bigop, keys, local_ops, n_samples, comm, nanisinf,
+ def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
_callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(StochasticEnergyAdapter, self).__init__(position)
- for op in local_ops:
- myassert(position.domain == op.domain)
+ for lop in local_ops:
+ myassert(position.domain == lop.domain)
self._comm = comm
self._local_ops = local_ops
self._n_samples = n_samples
lin = Linearization.make_var(position)
v, g = [], []
- for op in self._local_ops:
- tmp = op(lin)
+ for lop in self._local_ops:
+ tmp = lop(lin)
v.append(tmp.val.val)
g.append(tmp.gradient)
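+        # Average energy value and gradient over all sample operators,
+        # summing across MPI tasks if a communicator is given.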
self._val = allreduce_sum(v, self._comm)[()]/self._n_samples
@@ -119,7 +113,7 @@ class StochasticEnergyAdapter(Energy):
self._val = np.inf
self._grad = allreduce_sum(g, self._comm)/self._n_samples
- self._op = bigop
+ self._op = op
self._keys = keys
@property
@@ -131,8 +125,9 @@ class StochasticEnergyAdapter(Energy):
return self._grad
def at(self, position):
- return StochasticEnergyAdapter(position, self._local_ops,
- self._n_samples, self._comm, self._nanisinf)
+ return StochasticEnergyAdapter(position, self._op, self._keys,
+ self._local_ops, self._n_samples, self._comm, self._nanisinf,
+ _callingfrommake=True)
def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True)
@@ -149,20 +144,56 @@ class StochasticEnergyAdapter(Energy):
def resample_at(self, position):
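+        # `self._n_samples` already includes mirrored samples, so redraw
+        # without mirroring; pass the communicator and nan handling through
+        # by keyword to avoid positional mix-ups with `mirror_samples`.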
-        return StochasticEnergyAdapter.make(position, self._op, self._keys,
-                                        self._n_samples, self._comm)
+        return StochasticEnergyAdapter.make(position, self._op, self._keys,
+                                            self._n_samples, False,
+                                            comm=self._comm,
+                                            nanisinf=self._nanisinf)
@staticmethod
- def make(position, op, keys, n_samples, mirror_samples, nanisinf = False, comm=None):
- """Energy adapter where parts of the model are sampled.
+    def make(position, op, sampling_keys, n_samples, mirror_samples,
+             comm=None, nanisinf=False):
+ """A variant of `EnergyAdapter` that provides the energy interface for an
+ operator with a scalar target where parts of the imput are averaged
+ instead of optmized.
+
+ Specifically, a set of standart normal distributed
+ samples are drawn for the input corresponding to `keys` and each sample
+ gets partially inserted into `op`. The resulting operators are averaged and
+ represent a stochastic average of an energy with the remaining subdomain
+ being the DOFs that are considered to be optimization parameters.
+
+ Parameters
+ ----------
+        position : MultiField
+            Values of the optimization parameters.
+        op : Operator
+            The objective function of the optimization problem. Must have a
+            scalar target. The domain must be a `MultiDomain` with its keys
+            being the union of `sampling_keys` and `position.domain.keys()`.
+        sampling_keys : iterable of str
+            The keys of the subdomain over which the stochastic average of
+            `op` should be performed.
+ n_samples : int
+ Number of samples used for the stochastic estimate.
+        mirror_samples : bool
+            Whether the negatives of the drawn samples are also used, since
+            they are equally legitimate samples. If True, the number of used
+            samples doubles.
+ comm : MPI communicator or None
+ If not None, samples will be distributed as evenly as possible
+ across this communicator. If `mirror_samples` is set, then a sample
+ and its mirror image will always reside on the same task.
+ nanisinf : bool
+            If True, NaN energies, which can happen due to overflows in the
+            forward model, are interpreted as inf. Thereby, the code does not
+            crash on these occasions but rather the minimizer is told that
+            the position it has tried is not sensible.
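+
+        Examples
+        --------
+        A minimal sketch, assuming NIFTy is imported as ``ift``; the keys
+        ``"xi"`` (optimized over) and ``"eta"`` (averaged over) are
+        hypothetical:
+
+        >>> dom = ift.RGSpace(16)
+        >>> op = (ift.FieldAdapter(dom, "xi")
+        ...       + ift.FieldAdapter(dom, "eta")).ptw("exp").sum()
+        >>> pos = ift.from_random(ift.MultiDomain.make({"xi": dom}))
+        >>> energy = StochasticEnergyAdapter.make(
+        ...     pos, op, ("eta",), n_samples=2, mirror_samples=True)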
"""
+ myassert(op.target == DomainTuple.scalar_domain())
samdom = {}
- for k in keys:
- if k in position.domain.keys():
- raise ValueError
- if k not in op.domain.keys():
+        if not isinstance(n_samples, int):
+            raise TypeError("n_samples must be an integer")
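+        # Collect the sub-domains of `op` that will be averaged over; they
+        # must be disjoint from the optimization parameters in `position`.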
+ for k in sampling_keys:
+ if (k in position.domain.keys()) or (k not in op.domain.keys()):
-                raise ValueError
+                raise ValueError(f"invalid sampling key {k!r}")
- else:
- samdom[k] = op.domain[k]
+ samdom[k] = op.domain[k]
samdom = MultiDomain.make(samdom)
local_ops = []
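+        # Spawn one independent seed sequence per sample so every sample is
+        # drawn reproducibly.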
sseq = random.spawn_sseq(n_samples)
@@ -176,5 +207,5 @@ class StochasticEnergyAdapter(Energy):
if mirror_samples:
local_ops.append(op.simplify_for_constant_input(-rnd)[1])
n_samples = 2*n_samples if mirror_samples else n_samples
- return StochasticEnergyAdapter(position, op, keys, local_ops, n_samples,
- comm, nanisinf, _callingfrommake=True)
+ return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
+ n_samples, comm, nanisinf, _callingfrommake=True)