diff --git a/src/minimization/energy_adapter.py b/src/minimization/energy_adapter.py
index f26ee439fd6e3d908c53dd73baa9c9dbb7acb6ce..c9f486821346ee055e7f5d0331c1d2fb62286603 100644
--- a/src/minimization/energy_adapter.py
+++ b/src/minimization/energy_adapter.py
@@ -23,6 +23,8 @@ from ..minimization.energy import Energy
 from ..utilities import myassert, allreduce_sum
 from ..multi_domain import MultiDomain
 from ..sugar import from_random
+from ..domain_tuple import DomainTuple
+
 
 class EnergyAdapter(Energy):
     """Helper class which provides the traditional Nifty Energy interface to
@@ -90,28 +92,20 @@ class EnergyAdapter(Energy):
 
 
 class StochasticEnergyAdapter(Energy):
-    """A variant of `EnergyAdapter` that provides the energy interface for an
-    operator with a scalar target where parts of the imput are averaged
-    instead of optmized. Specifically, for the input corresponding to `keys`
-    a set of standart normal distributed samples are drawn and each gets
-    partially inserted into `bigop`. The results are averaged and represent a
-    stochastic average of an energy with the remaining subdomain being the DOFs
-    that are considered to be optimization parameters.
-    """
-    def __init__(self, position, bigop, keys, local_ops, n_samples, comm, nanisinf,
+    def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
                  _callingfrommake=False):
         if not _callingfrommake:
             raise NotImplementedError
         super(StochasticEnergyAdapter, self).__init__(position)
-        for op in local_ops:
-            myassert(position.domain == op.domain)
+        for lop in local_ops:
+            myassert(position.domain == lop.domain)
         self._comm = comm
         self._local_ops = local_ops
         self._n_samples = n_samples
         lin = Linearization.make_var(position)
         v, g = [], []
-        for op in self._local_ops:
-            tmp = op(lin)
+        for lop in self._local_ops:
+            tmp = lop(lin)
             v.append(tmp.val.val)
             g.append(tmp.gradient)
         self._val = allreduce_sum(v, self._comm)[()]/self._n_samples
@@ -119,7 +113,7 @@ class StochasticEnergyAdapter(Energy):
             self._val = np.inf
         self._grad = allreduce_sum(g, self._comm)/self._n_samples
-        self._op = bigop
+        self._op = op
         self._keys = keys
 
     @property
     def value(self):
@@ -131,8 +125,9 @@ class StochasticEnergyAdapter(Energy):
         return self._grad
 
     def at(self, position):
-        return StochasticEnergyAdapter(position, self._local_ops,
-                self._n_samples, self._comm, self._nanisinf)
+        return StochasticEnergyAdapter(position, self._op, self._keys,
+                self._local_ops, self._n_samples, self._comm, self._nanisinf,
+                _callingfrommake=True)
 
     def apply_metric(self, x):
         lin = Linearization.make_var(self.position, want_metric=True)
@@ -149,20 +144,56 @@ class StochasticEnergyAdapter(Energy):
 
     def resample_at(self, position):
         return StochasticEnergyAdapter.make(position, self._op, self._keys,
-                                        self._n_samples, self._comm)
+                                            self._n_samples, self._comm)
 
     @staticmethod
-    def make(position, op, keys, n_samples, mirror_samples, nanisinf = False, comm=None):
-        """Energy adapter where parts of the model are sampled.
+    def make(position, op, sampling_keys, n_samples, mirror_samples,
+             comm=None, nanisinf=False):
+        """A variant of `EnergyAdapter` that provides the energy interface for an
+        operator with a scalar target where parts of the input are averaged
+        instead of optimized.
+
+        Specifically, a set of standard normally distributed samples is
+        drawn for the input corresponding to `sampling_keys`; each sample is
+        partially inserted into `op`. The resulting operators are averaged and
+        represent a stochastic average of an energy with the remaining
+        subdomain being the DOFs that are considered optimization parameters.
+
+        Parameters
+        ----------
+        position : MultiField
+            Values of the optimization parameters.
+        op : Operator
+            The objective function of the optimization problem. Must have a
+            scalar target. The domain must be a `MultiDomain` with its keys
+            being the union of `sampling_keys` and `position.domain.keys()`.
+        sampling_keys : iterable of String
+            The keys of the subdomain over which the stochastic average of `op`
+            should be performed.
+        n_samples : int
+            Number of samples used for the stochastic estimate.
+        mirror_samples : boolean
+            Whether the negatives of the drawn samples are also used, as they
+            are equally legitimate samples. If true, the number of used
+            samples doubles.
+        comm : MPI communicator or None
+            If not None, samples will be distributed as evenly as possible
+            across this communicator. If `mirror_samples` is set, then a sample
+            and its mirror image will always reside on the same task.
+        nanisinf : bool
+            If true, nan energies, which can happen due to overflows in the
+            forward model, are interpreted as inf. Thereby, the code does not
+            crash on these occasions but rather the minimizer is told that the
+            position it has tried is not sensible.
         """
+        myassert(op.target == DomainTuple.scalar_domain())
         samdom = {}
-        for k in keys:
-            if k in position.domain.keys():
-                raise ValueError
-            if k not in op.domain.keys():
+        if not isinstance(n_samples, int):
+            raise TypeError
+        for k in sampling_keys:
+            if (k in position.domain.keys()) or (k not in op.domain.keys()):
                 raise ValueError
-            else:
-                samdom[k] = op.domain[k]
+            samdom[k] = op.domain[k]
         samdom = MultiDomain.make(samdom)
         local_ops = []
         sseq = random.spawn_sseq(n_samples)
@@ -176,5 +207,5 @@ class StochasticEnergyAdapter(Energy):
         if mirror_samples:
             local_ops.append(op.simplify_for_constant_input(-rnd)[1])
         n_samples = 2*n_samples if mirror_samples else n_samples
-        return StochasticEnergyAdapter(position, op, keys, local_ops, n_samples,
-                                       comm, nanisinf, _callingfrommake=True)
+        return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
+                                       n_samples, comm, nanisinf, _callingfrommake=True)
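
Reviewer note, not part of the patch: a minimal usage sketch of the new `make` signature and of the `at`/`resample_at` behavior the patch fixes. Everything below is hypothetical scaffolding -- the packaged `nifty8` namespace, the import path mirroring src/minimization/energy_adapter.py, the `GaussianEnergy` objective, and the key names 'params'/'latent' are assumptions for illustration, not mandated by the diff.

    import nifty8 as ift  # assumed packaged namespace of this repository
    # import path assumed to mirror src/minimization/energy_adapter.py
    from nifty8.minimization.energy_adapter import StochasticEnergyAdapter

    # Hypothetical scalar-target objective whose MultiDomain has the keys
    # {'params', 'latent'}: 'params' is optimized, 'latent' is averaged over.
    dom = ift.RGSpace(16)
    op = ift.GaussianEnergy(domain=dom) @ (
        ift.FieldAdapter(dom, 'params') + ift.FieldAdapter(dom, 'latent'))

    # position covers only the optimization keys, per the `make` docstring.
    position = ift.from_random(ift.MultiDomain.make({'params': dom}))
    energy = StochasticEnergyAdapter.make(
        position, op, sampling_keys=('latent',), n_samples=4,
        mirror_samples=True,  # 8 samples effectively: each draw and its negative
        comm=None, nanisinf=True)

    # The adapter satisfies the usual Energy interface; `at` re-uses the
    # drawn samples (the call this patch repairs), `resample_at` draws new ones.
    val, grad = energy.value, energy.gradient
    moved = energy.at(position)
    fresh = energy.resample_at(position)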