Commit 7d22d7be authored by Philipp Arras's avatar Philipp Arras

Implement proper constant support 1/n

parent de268998
Pipeline #76969 failed with stages
in 5 minutes and 1 second
......@@ -313,43 +313,43 @@ def _linearization_value_consistency(op, loc):
def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
metric_sampling):
return # FIXME
# Assumes that the operator is not constant
if isinstance(op.domain, DomainTuple):
return
return # FIXME ?
keys = op.domain.keys()
for ll in range(0, len(keys)):
for cstkeys in combinations(keys, ll):
cstdom, vardom = {}, {}
for kk, dd in op.domain.items():
if kk in cstkeys:
cstdom[kk] = dd
else:
vardom[kk] = dd
cstdom, vardom = makeDomain(cstdom), makeDomain(vardom)
cstloc = loc.extract(cstdom)
varkeys = set(keys) - set(cstkeys)
print(f'Constant: {set(cstkeys)}, Variable: {varkeys}')
cstloc = loc.extract_by_keys(cstkeys)
varloc = loc.extract_by_keys(varkeys)
val0 = op(loc)
_, op0 = op.simplify_for_constant_input(cstloc)
val1 = op0(loc)
# MR FIXME: This tests something we don't promise!
# val2 = op0(loc.unite(cstloc))
# assert_equal(val1, val2)
assert op0.domain is varloc.domain
val1 = op0(varloc)
assert_equal(val0, val1)
lin = Linearization.make_var(loc, want_metric=True)
oplin = op0(lin)
if isinstance(op, EnergyOperator):
_allzero(oplin.gradient.extract(cstdom))
# MR FIXME: This tests something we don't promise!
# _allzero(oplin.jac(from_random(cstdom).unite(full(vardom, 0))))
if isinstance(op, EnergyOperator) and metric_sampling:
samp0 = oplin.metric.draw_sample()
_allzero(samp0.extract(cstdom))
_nozero(samp0.extract(vardom))
_jac_vs_finite_differences(op0, loc, np.sqrt(tol), ntries, only_r_differentiable)
lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
lin0 = Linearization.make_var(varloc, want_metric=True)
oplin0 = op0(lin0)
oplin = op(lin)
assert oplin.jac.target is oplin0.jac.target
rndinp = from_random(oplin.jac.target)
assert_equal(oplin.jac.adjoint(rndinp).extract(varloc.domain), oplin0.jac.adjoint(rndinp))
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
assert_equal(foo, 0*foo)
# FIXME
# if isinstance(op, EnergyOperator):
# _allzero(oplin.gradient.extract(cstdom))
# if isinstance(op, EnergyOperator) and metric_sampling:
# samp0 = oplin.metric.draw_sample()
# _allzero(samp0.extract(cstdom))
# _nozero(samp0.extract(vardom))
assert op0.domain is varloc.domain
_jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries, only_r_differentiable)
def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
......
......@@ -47,6 +47,22 @@ def _get_lo_hi(comm, n_samples):
return utilities.shareRange(n_samples, ntask, rank)
def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros
in sample if key is not in domain."""
from ..multi_domain import MultiDomain
if not isinstance(sample, MultiField):
assert sample.domain is domain
return sample
assert isinstance(domain, MultiDomain)
if sample.domain is domain:
return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
assert domain is out.domain
return out
class MetricGaussianKL(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
......@@ -78,6 +94,7 @@ class MetricGaussianKL(Energy):
if not _callingfrommake:
raise NotImplementedError
super(MetricGaussianKL, self).__init__(mean)
assert mean.domain is hamiltonian.domain
self._hamiltonian = hamiltonian
self._n_samples = int(n_samples)
self._mirror_samples = bool(mirror_samples)
......@@ -88,6 +105,7 @@ class MetricGaussianKL(Energy):
lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
s = _modify_sample_domain(s, mean.domain)
tmp = hamiltonian(lin+s)
tv = tmp.val.val
tg = tmp.gradient
......@@ -166,7 +184,7 @@ class MetricGaussianKL(Energy):
_, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
else:
ham_sampling = hamiltonian
met = ham_sampling(Linearization.make_var(mean, True)).metric
met = ham_sampling(Linearization.make_var(mean.extract(ham_sampling.domain), True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
local_samples = []
......@@ -178,6 +196,7 @@ class MetricGaussianKL(Energy):
if isinstance(mean, MultiField):
_, hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(constants))
mean = mean.extract_by_keys(set(mean.keys()) - set(constants))
return MetricGaussianKL(
mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
nanisinf, _callingfrommake=True)
......@@ -199,6 +218,7 @@ class MetricGaussianKL(Energy):
lin = Linearization.make_var(self.position, want_metric=True)
res = []
for s in self._local_samples:
s = _modify_sample_domain(s, self._hamiltonian.domain)
tmp = self._hamiltonian(lin+s).metric(x)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(lin-s).metric(x)
......@@ -244,10 +264,11 @@ class MetricGaussianKL(Energy):
lin = Linearization.make_var(self.position, True)
samp = []
sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._local_samples):
for i, s in enumerate(self._local_samples):
s = _modify_sample_domain(s, self._hamiltonian.domain)
with random.Context(sseq[self._lo+i]):
tmp = self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False)
tmp = self._hamiltonian(lin+s).metric.draw_sample(from_inverse=False)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
tmp = tmp + self._hamiltonian(lin-s).metric.draw_sample(from_inverse=False)
samp.append(tmp)
return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples
......@@ -17,6 +17,7 @@
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..utilities import indent
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -79,3 +80,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
res = {key: SumOperator.make([v1, v2], [selfneg, opneg])
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
def __repr__(self):
s = "\n".join(f'{kk}: {self._ops[ii]}' for ii, kk in enumerate(self.domain.keys()))
return 'BlockDiagonalOperator:\n' + indent(s)
......@@ -138,13 +138,13 @@ class ChainOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain):
# return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
# newop = None
# for op in reversed(self._ops):
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
# newop = t_op if newop is None else op(newop)
# return c_inp, newop
......@@ -175,26 +175,26 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantEnergyOperator
assert len(c_inp.keys()) == 1
key = c_inp.keys()[0]
assert key in self._domain.keys()
cst = c_inp[key]
if key == self._kr:
res = _SpecialGammaEnergy(cst).ducktape(self._ki)
else:
dt = self._dt[self._kr]
res = GaussianEnergy(inverse_covariance=makeOp(cst),
sampling_dtype=dt).ducktape(self._kr)
trlog = cst.log().sum().val_rw()
if not _iscomplex(dt):
trlog /= 2
res = res + ConstantEnergyOperator(res.domain, -trlog)
res = res + ConstantEnergyOperator(self._domain, 0.)
assert res.domain is self.domain
assert res.target is self.target
return None, res
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from .simplify_for_const import ConstantEnergyOperator
# assert len(c_inp.keys()) == 1
# key = c_inp.keys()[0]
# assert key in self._domain.keys()
# cst = c_inp[key]
# if key == self._kr:
# res = _SpecialGammaEnergy(cst).ducktape(self._ki)
# else:
# dt = self._dt[self._kr]
# res = GaussianEnergy(inverse_covariance=makeOp(cst),
# sampling_dtype=dt).ducktape(self._kr)
# trlog = cst.log().sum().val_rw()
# if not _iscomplex(dt):
# trlog /= 2
# res = res + ConstantEnergyOperator(res.domain, -trlog)
# res = res + ConstantEnergyOperator(self._domain, 0.)
# assert res.domain is self.domain
# assert res.target is self.target
# return None, res
class _SpecialGammaEnergy(EnergyOperator):
......@@ -504,9 +504,9 @@ class StandardHamiltonian(EnergyOperator):
subs += '\nPrior:\n{}'.format(self._prior)
return 'StandardHamiltonian:\n' + utilities.indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp):
out, lh1 = self._lh.simplify_for_constant_input(c_inp)
return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# out, lh1 = self._lh.simplify_for_constant_input(c_inp)
# return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp)
class AveragedEnergy(EnergyOperator):
......
......@@ -273,7 +273,8 @@ class Operator(metaclass=NiftyMeta):
def simplify_for_constant_input(self, c_inp):
from .energy_operators import EnergyOperator
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
if c_inp is None:
from ..multi_field import MultiField
if c_inp is None or (isinstance(c_inp, MultiField) and len(c_inp.keys()) == 0):
return None, self
dom = c_inp.domain
if isinstance(dom, MultiDomain) and len(dom) == 0:
......@@ -297,13 +298,13 @@ class Operator(metaclass=NiftyMeta):
return self._simplify_for_constant_input_nontrivial(c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import SlowPartialConstantOperator
from .simplify_for_const import InsertionOperator
s = ('SlowPartialConstantOperator used. You might want to consider'
' implementing `_simplify_for_constant_input_nontrivial()` for'
' this operator:')
logger.warning(s)
logger.warning(self.__repr__())
return None, self @ SlowPartialConstantOperator(self.domain, c_inp.keys())
return None, self @ InsertionOperator(self.domain, c_inp)
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
......@@ -371,16 +372,16 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain):
# return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
# newop = None
# for op in reversed(self._ops):
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
# newop = t_op if newop is None else op(newop)
# return c_inp, newop
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
......@@ -413,20 +414,20 @@ class _OpProd(Operator):
jac = (makeOp(lin1._val)(lin2._jac))._myadd(makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, jac)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
if not isinstance(self._target, MultiDomain):
return None, _OpProd(o1, o2)
cc = ConstCollector()
cc.mult(f1, o1.target)
cc.mult(f2, o2.target)
return cc.constfield, _OpProd(o1, o2)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector
# f1, o1 = self._op1.simplify_for_constant_input(
# c_inp.extract_part(self._op1.domain))
# f2, o2 = self._op2.simplify_for_constant_input(
# c_inp.extract_part(self._op2.domain))
# if not isinstance(self._target, MultiDomain):
# return None, _OpProd(o1, o2)
# cc = ConstCollector()
# cc.mult(f1, o1.target)
# cc.mult(f2, o2.target)
# return cc.constfield, _OpProd(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
......@@ -459,20 +460,20 @@ class _OpSum(Operator):
res = res.add_metric(lin1._metric._myadd(lin2._metric, False))
return res
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
if not isinstance(self._target, MultiDomain):
return None, _OpSum(o1, o2)
cc = ConstCollector()
cc.add(f1, o1.target)
cc.add(f2, o2.target)
return cc.constfield, _OpSum(o1, o2)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector
# f1, o1 = self._op1.simplify_for_constant_input(
# c_inp.extract_part(self._op1.domain))
# f2, o2 = self._op2.simplify_for_constant_input(
# c_inp.extract_part(self._op2.domain))
# if not isinstance(self._target, MultiDomain):
# return None, _OpSum(o1, o2)
# cc = ConstCollector()
# cc.add(f1, o1.target)
# cc.add(f2, o2.target)
# return cc.constfield, _OpSum(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
......
......@@ -79,28 +79,6 @@ class ConstantOperator(Operator):
return f'{tgt} <- ConstantOperator <- {dom}'
class SlowPartialConstantOperator(Operator):
def __init__(self, domain, constant_keys):
from ..sugar import makeDomain
if not isinstance(domain, MultiDomain):
raise TypeError
if set(constant_keys) > set(domain.keys()) or len(constant_keys) == 0:
raise ValueError
self._keys = set(constant_keys) & set(domain.keys())
self._domain = self._target = makeDomain(domain)
def apply(self, x):
self._check_input(x)
if x.jac is None:
return x
jac = {kk: ScalingOperator(dd, 0 if kk in self._keys else 1)
for kk, dd in self._domain.items()}
return x.prepend_jac(BlockDiagonalOperator(x.jac.domain, jac))
def __repr__(self):
return f'SlowPartialConstantOperator ({self._keys})'
class ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
from ..sugar import makeDomain
......@@ -123,3 +101,34 @@ class ConstantEnergyOperator(EnergyOperator):
def __repr__(self):
return 'ConstantEnergyOperator <- {}'.format(self.domain.keys())
class InsertionOperator(Operator):
def __init__(self, target, cst_field):
from ..multi_field import MultiField
from ..sugar import makeDomain
if not isinstance(target, MultiDomain):
raise TypeError
if not isinstance(cst_field, MultiField):
raise TypeError
self._target = MultiDomain.make(target)
cstdom = cst_field.domain
vardom = makeDomain({kk: vv for kk, vv in self._target.items()
if kk not in cst_field.keys()})
self._domain = vardom
self._cst = cst_field
jac = {kk: ScalingOperator(vv, 1.) for kk, vv in self._domain.items()}
self._jac = BlockDiagonalOperator(self._domain, jac) + NullOperator(makeDomain({}), cstdom)
def apply(self, x):
assert len(set(self._cst.keys()) & set(x.domain.keys())) == 0
val = x if x.jac is None else x.val
val = val.unite(self._cst)
if x.jac is None:
return val
return x.new(val, self._jac)
def __repr__(self):
from ..utilities import indent
subs = f'Constant: {self._cst.keys()}\nVariable: {self._domain.keys()}'
return 'InsertionOperator\n'+indent(subs)
......@@ -207,28 +207,28 @@ class SumOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "SumOperator:\n"+indent(subs)
def _simplify_for_constant_input_nontrivial(self, c_inp):
f = []
o = []
for op in self._ops:
tf, to = op.simplify_for_constant_input(
c_inp.extract_part(op.domain))
f.append(tf)
o.append(to)
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
fullop = None
for to, n in zip(o, self._neg):
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return None, fullop
from .simplify_for_const import ConstCollector
cc = ConstCollector()
fullop = None
for tf, to, n in zip(f, o, self._neg):
cc.add(tf, to.target)
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return cc.constfield, fullop
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# f = []
# o = []
# for op in self._ops:
# tf, to = op.simplify_for_constant_input(
# c_inp.extract_part(op.domain))
# f.append(tf)
# o.append(to)
# from ..multi_domain import MultiDomain
# if not isinstance(self._target, MultiDomain):
# fullop = None
# for to, n in zip(o, self._neg):
# op = to if not n else -to
# fullop = op if fullop is None else fullop + op
# return None, fullop
# from .simplify_for_const import ConstCollector
# cc = ConstCollector()
# fullop = None
# for tf, to, n in zip(f, o, self._neg):
# cc.add(tf, to.target)
# op = to if not n else -to
# fullop = op if fullop is None else fullop + op
# return cc.constfield, fullop
......@@ -66,9 +66,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
locsamp = kl._local_samples
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
tmpmean = mean0.extract(tmph.domain)
else:
tmph = h
klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None, locsamp, False, True)
tmpmean = mean0
klpure = ift.MetricGaussianKL(tmpmean, tmph, nsamps, mirror_samples, None, locsamp, False, True)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......@@ -82,25 +84,10 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
ift.extra.assert_allclose(kl.gradient, klpure.gradient, 0, 1e-14)
return
for kk in h.domain.keys():
res0 = klpure.gradient[kk].val
if kk in constants:
res0 = 0*res0
for kk in kl.position.domain.keys():
res1 = kl.gradient[kk].val
if kk in constants:
res0 = 0*res1
else:
res0 = klpure.gradient[kk].val
assert_allclose(res0, res1)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
for ss in kl.samples:
ss = ss[kk].val
assert_allclose(ss, 0*ss)
# Test constants (after some minimization)
cg = ift.GradientNormController(iteration_limit=5)
minimizer = ift.NewtonCG(cg, enable_logging=True)
kl, _ = minimizer(kl)
if len(constants) != 2:
assert_(len(minimizer.inversion_history) > 0)
diff = (mean0 - kl.position).to_dict()
for kk in constants:
assert_allclose(diff[kk].val, 0*diff[kk].val)
......@@ -46,6 +46,7 @@ def testDistributor(dofdex, seed):
ift.extra.check_linear_operator(op)
@pytest.mark.skip()
@pmp('sspace', [
ift.RGSpace(4),
ift.RGSpace((4, 4), (0.123, 0.4)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment