Commit f30d1547 authored by Philipp Frank's avatar Philipp Frank
Browse files

restructure

parent ac4274c5
Pipeline #102736 failed with stages
in 50 seconds
......@@ -72,7 +72,7 @@ from .minimization.stochastic_minimizer import ADVIOptimizer
from .minimization.scipy_minimizer import L_BFGS_B
from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter
from .minimization.energy_adapter import EnergyAdapter, StochasticEnergyAdapter
from .minimization.kl_energies import MetricGaussianKL, GeoMetricKL
from .sugar import *
......
......@@ -17,10 +17,13 @@
import numpy as np
from .. import random
from ..linearization import Linearization
from ..minimization.energy import Energy
from ..sugar import makeDomain
from ..utilities import myassert, allreduce_sum
from ..multi_domain import MultiDomain
from ..sugar import from_random
from .kl_energies import _SelfAdjointOperatorWrapper, _get_lo_hi
class EnergyAdapter(Energy):
"""Helper class which provides the traditional Nifty Energy interface to
......@@ -85,3 +88,84 @@ class EnergyAdapter(Energy):
def apply_metric(self, x):
return self._metric(x)
class StochasticEnergyAdapter(Energy):
def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
_callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(StochasticEnergyAdapter, self).__init__(position)
for op in local_ops:
myassert(position.domain == op.domain)
self._comm = comm
self._local_ops = local_ops
self._n_samples = n_samples
lin = Linearization.make_var(position)
v, g = [], []
for op in self._local_ops:
tmp = op(lin)
v.append(tmp.val.val)
g.append(tmp.gradient)
self._val = allreduce_sum(v, self._comm)[()]/self._n_samples
if np.isnan(self._val) and self._nanisinf:
self._val = np.inf
self._grad = allreduce_sum(g, self._comm)/self._n_samples
self._op = op
self._keys = keys
@property
def value(self):
return self._val
@property
def gradient(self):
return self._grad
def at(self, position):
return StochasticEnergyAdapter(position, self._local_ops,
self._n_samples, self._comm, self._nanisinf)
def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True)
res = []
for op in self._local_ops:
res.append(op(lin).metric(x))
return allreduce_sum(res, self._comm)/self._n_samples
@property
def metric(self):
return _SelfAdjointOperatorWrapper(self.position.domain,
self.apply_metric)
def resample_at(self, position):
return StochasticEnergyAdapter.make(position, self._op, self._keys,
self._n_samples, self._comm)
@staticmethod
def make(position, op, keys, n_samples, mirror_samples, nanisinf = False, comm=None):
"""Energy adapter where parts of the model are sampled.
"""
samdom = {}
for k in keys:
if k in position.domain.keys():
raise ValueError
if k not in op.domain.keys():
raise ValueError
else:
samdom[k] = op.domain[k]
samdom = MultiDomain.make(samdom)
local_ops = []
sseq = random.spawn_sseq(n_samples)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
rnd = from_random(samdom)
_, tmp = op.simplify_for_constant_input(rnd)
myassert(tmp.domain == position.domain)
local_ops.append(tmp)
if mirror_samples:
local_ops.append(op.simplify_for_constant_input(-rnd)[1])
n_samples = 2*n_samples if mirror_samples else n_samples
return StochasticEnergyAdapter(position, op, keys, local_ops, n_samples,
comm, nanisinf, _callingfrommake=True)
......@@ -15,107 +15,8 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .minimizer import Minimizer
from .energy import Energy
from .kl_energies import _SelfAdjointOperatorWrapper, _get_lo_hi
from ..linearization import Linearization
from ..utilities import myassert, allreduce_sum
from ..multi_domain import MultiDomain
from ..sugar import from_random
from .. import random
class _StochasticEnergyAdapter(Energy):
def __init__(self, position, local_ops, n_samples, comm, nanisinf):
super(_StochasticEnergyAdapter, self).__init__(position)
for op in local_ops:
myassert(position.domain == op.domain)
self._comm = comm
self._local_ops = local_ops
self._n_samples = n_samples
lin = Linearization.make_var(position)
v, g = [], []
for op in self._local_ops:
tmp = op(lin)
v.append(tmp.val.val)
g.append(tmp.gradient)
self._val = allreduce_sum(v, self._comm)[()]/self._n_samples
if np.isnan(self._val) and self._nanisinf:
self._val = np.inf
self._grad = allreduce_sum(g, self._comm)/self._n_samples
@property
def value(self):
return self._val
@property
def gradient(self):
return self._grad
def at(self, position):
return _StochasticEnergyAdapter(position, self._local_ops,
self._n_samples, self._comm, self._nanisinf)
def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True)
res = []
for op in self._local_ops:
res.append(op(lin).metric(x))
return allreduce_sum(res, self._comm)/self._n_samples
@property
def metric(self):
return _SelfAdjointOperatorWrapper(self.position.domain,
self.apply_metric)
class PartialSampledEnergy(_StochasticEnergyAdapter):
def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
_callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(PartialSampledEnergy, self).__init__(position,
local_ops, n_samples, comm, nanisinf)
self._op = op
self._keys = keys
def at(self, position):
return PartialSampledEnergy(position, self._op, self._keys,
self._local_ops, self._n_samples,
self._comm, self._nanisinf,
_callingfrommake=True)
def resample_at(self, position):
return PartialSampledEnergy.make(position, self._op, self._keys,
self._n_samples, self._comm)
@staticmethod
def make(position, op, keys, n_samples, mirror_samples, nanisinf = False, comm=None):
samdom = {}
for k in keys:
if k in position.domain.keys():
raise ValueError
if k not in op.domain.keys():
raise ValueError
else:
samdom[k] = op.domain[k]
samdom = MultiDomain.make(samdom)
local_ops = []
sseq = random.spawn_sseq(n_samples)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
rnd = from_random(samdom)
_, tmp = op.simplify_for_constant_input(rnd)
myassert(tmp.domain == position.domain)
local_ops.append(tmp)
if mirror_samples:
local_ops.append(op.simplify_for_constant_input(-rnd)[1])
n_samples = 2*n_samples if mirror_samples else n_samples
return PartialSampledEnergy(position, op, keys, local_ops, n_samples,
comm, nanisinf, _callingfrommake=True)
class ADVIOptimizer(Minimizer):
"""Provide an implementation of an adaptive step-size sequence optimizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment