Commit dc9494c9 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'constant_support_pa' into 'NIFTy_7'

Proper constants

See merge request !545
parents 04c52785 3aeba77e
Pipeline #77081 passed with stages
in 13 minutes and 18 seconds
......@@ -20,6 +20,7 @@ from itertools import combinations
import numpy as np
from numpy.testing import assert_
from . import random
from .domain_tuple import DomainTuple
from .field import Field
from .linearization import Linearization
......@@ -28,7 +29,7 @@ from .multi_field import MultiField
from .operators.energy_operators import EnergyOperator
from .operators.linear_operator import LinearOperator
from .operators.operator import Operator
from .sugar import from_random, makeDomain
from .sugar import from_random
__all__ = ["check_linear_operator", "check_operator",
"assert_allclose"]
......@@ -117,7 +118,8 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
_domain_check_nonlinear(op, loc)
_performance_check(op, loc, bool(perf_check))
_linearization_value_consistency(op, loc)
_jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries, only_r_differentiable)
_jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries,
only_r_differentiable)
_check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
metric_sampling)
......@@ -313,43 +315,46 @@ def _linearization_value_consistency(op, loc):
def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
metric_sampling):
return # FIXME
# Assumes that the operator is not constant
if isinstance(op.domain, DomainTuple):
return
keys = op.domain.keys()
for ll in range(0, len(keys)):
for cstkeys in combinations(keys, ll):
cstdom, vardom = {}, {}
for kk, dd in op.domain.items():
if kk in cstkeys:
cstdom[kk] = dd
else:
vardom[kk] = dd
cstdom, vardom = makeDomain(cstdom), makeDomain(vardom)
cstloc = loc.extract(cstdom)
val0 = op(loc)
_, op0 = op.simplify_for_constant_input(cstloc)
val1 = op0(loc)
# MR FIXME: This tests something we don't promise!
# val2 = op0(loc.unite(cstloc))
# assert_equal(val1, val2)
assert_equal(val0, val1)
lin = Linearization.make_var(loc, want_metric=True)
oplin = op0(lin)
if isinstance(op, EnergyOperator):
_allzero(oplin.gradient.extract(cstdom))
# MR FIXME: This tests something we don't promise!
# _allzero(oplin.jac(from_random(cstdom).unite(full(vardom, 0))))
if isinstance(op, EnergyOperator) and metric_sampling:
samp0 = oplin.metric.draw_sample()
_allzero(samp0.extract(cstdom))
_nozero(samp0.extract(vardom))
_jac_vs_finite_differences(op0, loc, np.sqrt(tol), ntries, only_r_differentiable)
combis = []
if len(keys) > 4:
from .logger import logger
logger.warning('Operator domain has more than 4 keys.')
logger.warning('Check derivatives only with one constant key at a time.')
combis = [[kk] for kk in keys]
else:
for ll in range(1, len(keys)):
combis.extend(list(combinations(keys, ll)))
for cstkeys in combis:
varkeys = set(keys) - set(cstkeys)
cstloc = loc.extract_by_keys(cstkeys)
varloc = loc.extract_by_keys(varkeys)
val0 = op(loc)
_, op0 = op.simplify_for_constant_input(cstloc)
assert op0.domain is varloc.domain
val1 = op0(varloc)
assert_equal(val0, val1)
lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
lin0 = Linearization.make_var(varloc, want_metric=True)
oplin0 = op0(lin0)
oplin = op(lin)
assert oplin.jac.target is oplin0.jac.target
rndinp = from_random(oplin.jac.target)
assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
assert_equal(foo, 0*foo)
if isinstance(op, EnergyOperator) and metric_sampling:
oplin.metric.draw_sample()
# _jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries,
# only_r_differentiable)
def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
......@@ -379,4 +384,5 @@ def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
loc = locnext
check_linear_operator(linmid.jac, domain_dtype=loc.dtype,
target_dtype=dirder.dtype,
only_r_linear=only_r_differentiable)
only_r_linear=only_r_differentiable,
atol=tol**2, rtol=tol**2)
......@@ -47,6 +47,27 @@ def _get_lo_hi(comm, n_samples):
return utilities.shareRange(n_samples, ntask, rank)
def _modify_sample_domain(sample, domain):
"""Takes only keys from sample which are also in domain and inserts zeros
for keys which are not in sample.domain."""
from ..multi_domain import MultiDomain
from ..field import Field
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
domain = makeDomain(domain)
if isinstance(domain, DomainTuple) and isinstance(sample, Field):
if sample.domain is not domain:
raise TypeError
return sample
elif isinstance(domain, MultiDomain) and isinstance(sample, MultiField):
if sample.domain is domain:
return sample
out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
out = MultiField.from_dict(out, domain)
return out
raise TypeError
class MetricGaussianKL(Energy):
"""Provides the sampled Kullback-Leibler divergence between a distribution
and a Metric Gaussian.
......@@ -78,6 +99,7 @@ class MetricGaussianKL(Energy):
if not _callingfrommake:
raise NotImplementedError
super(MetricGaussianKL, self).__init__(mean)
assert mean.domain is hamiltonian.domain
self._hamiltonian = hamiltonian
self._n_samples = int(n_samples)
self._mirror_samples = bool(mirror_samples)
......@@ -88,6 +110,7 @@ class MetricGaussianKL(Energy):
lin = Linearization.make_var(mean)
v, g = [], []
for s in self._local_samples:
s = _modify_sample_domain(s, mean.domain)
tmp = hamiltonian(lin+s)
tv = tmp.val.val
tg = tmp.gradient
......@@ -166,7 +189,8 @@ class MetricGaussianKL(Energy):
_, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
else:
ham_sampling = hamiltonian
met = ham_sampling(Linearization.make_var(mean, True)).metric
lin = Linearization.make_var(mean.extract(ham_sampling.domain), True)
met = ham_sampling(lin).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
local_samples = []
......@@ -178,6 +202,7 @@ class MetricGaussianKL(Energy):
if isinstance(mean, MultiField):
_, hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(constants))
mean = mean.extract_by_keys(set(mean.keys()) - set(constants))
return MetricGaussianKL(
mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
nanisinf, _callingfrommake=True)
......@@ -199,6 +224,7 @@ class MetricGaussianKL(Energy):
lin = Linearization.make_var(self.position, want_metric=True)
res = []
for s in self._local_samples:
s = _modify_sample_domain(s, self._hamiltonian.domain)
tmp = self._hamiltonian(lin+s).metric(x)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(lin-s).metric(x)
......@@ -244,10 +270,11 @@ class MetricGaussianKL(Energy):
lin = Linearization.make_var(self.position, True)
samp = []
sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._local_samples):
for i, s in enumerate(self._local_samples):
s = _modify_sample_domain(s, self._hamiltonian.domain)
with random.Context(sseq[self._lo+i]):
tmp = self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False)
tmp = self._hamiltonian(lin+s).metric.draw_sample(from_inverse=False)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
tmp = tmp + self._hamiltonian(lin-s).metric.draw_sample(from_inverse=False)
samp.append(tmp)
return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples
......@@ -17,6 +17,7 @@
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..utilities import indent
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -79,3 +80,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
res = {key: SumOperator.make([v1, v2], [selfneg, opneg])
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
def __repr__(self):
s = "\n".join(f'{kk}: {self._ops[ii]}' for ii, kk in enumerate(self.domain.keys()))
return 'BlockDiagonalOperator:\n' + indent(s)
......@@ -58,7 +58,14 @@ class ChainOperator(LinearOperator):
fct = 1.
opsnew = []
lastdom = ops[-1].domain
dtype = None
for op in ops:
from .sampling_enabler import SamplingDtypeSetter
if isinstance(op, SamplingDtypeSetter) and isinstance(op._op, ScalingOperator):
if dtype is not None:
raise NotImplementedError
dtype = op._dtype
op = op._op
if (isinstance(op, ScalingOperator) and op._factor.imag == 0):
fct *= op._factor.real
else:
......@@ -72,7 +79,10 @@ class ChainOperator(LinearOperator):
break
if fct != 1 or len(opsnew) == 0:
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(lastdom, fct))
op = ScalingOperator(lastdom, fct)
if dtype is not None:
op = SamplingDtypeSetter(op, dtype)
opsnew.append(op)
ops = opsnew
# combine DiagonalOperators where possible
opsnew = []
......@@ -142,7 +152,6 @@ class ChainOperator(LinearOperator):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
......
......@@ -190,9 +190,8 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
trlog = cst.log().sum().val_rw()
if not _iscomplex(dt):
trlog /= 2
res = res + ConstantEnergyOperator(res.domain, -trlog)
res = res + ConstantEnergyOperator(self._domain, 0.)
assert res.domain is self.domain
res = res + ConstantEnergyOperator(-trlog)
res = res + ConstantEnergyOperator(0.)
assert res.target is self.target
return None, res
......@@ -491,11 +490,9 @@ class StandardHamiltonian(EnergyOperator):
`<https://arxiv.org/abs/1812.04403>`_
"""
def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64):
def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
self._lh = lh
self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
if _c_inp is not None:
_, self._prior = self._prior.simplify_for_constant_input(_c_inp)
self._ic_samp = ic_samp
self._domain = lh.domain
......@@ -513,7 +510,7 @@ class StandardHamiltonian(EnergyOperator):
def _simplify_for_constant_input_nontrivial(self, c_inp):
out, lh1 = self._lh.simplify_for_constant_input(c_inp)
return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp)
return out, StandardHamiltonian(lh1, self._ic_samp)
class AveragedEnergy(EnergyOperator):
......
......@@ -273,7 +273,10 @@ class Operator(metaclass=NiftyMeta):
def simplify_for_constant_input(self, c_inp):
from .energy_operators import EnergyOperator
from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
if c_inp is None:
from ..multi_field import MultiField
from ..domain_tuple import DomainTuple
from ..sugar import makeDomain
if c_inp is None or (isinstance(c_inp, MultiField) and len(c_inp.keys()) == 0):
return None, self
dom = c_inp.domain
if isinstance(dom, MultiDomain) and len(dom) == 0:
......@@ -283,27 +286,36 @@ class Operator(metaclass=NiftyMeta):
# subdomain of self._domain
if isinstance(self.domain, MultiDomain):
assert isinstance(dom, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
if not set(c_inp.keys()) <= set(self.domain.keys()):
raise ValueError
if dom is self.domain:
if isinstance(self, DomainTuple):
raise RuntimeError
if isinstance(self, EnergyOperator):
op = ConstantEnergyOperator(self.domain, self(c_inp))
op = ConstantEnergyOperator(self(c_inp))
else:
op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
op = ConstantOperator(self(c_inp))
return None, op
if not isinstance(dom, MultiDomain):
raise RuntimeError
return self._simplify_for_constant_input_nontrivial(c_inp)
c_out, op = self._simplify_for_constant_input_nontrivial(c_inp)
vardom = makeDomain({kk: vv for kk, vv in self.domain.items()
if kk not in c_inp.keys()})
assert op.domain is vardom
assert op.target is self.target
assert isinstance(op, Operator)
if c_out is not None:
assert isinstance(c_out, MultiField)
assert len(set(c_out.keys()) & self.domain.keys()) == 0
assert set(c_out.keys()) <= set(c_inp.keys())
return c_out, op
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import SlowPartialConstantOperator
s = ('SlowPartialConstantOperator used. You might want to consider'
' implementing `_simplify_for_constant_input_nontrivial()` for'
' this operator:')
logger.warning(s)
from .simplify_for_const import InsertionOperator
logger.warning('SlowPartialConstantOperator used for:')
logger.warning(self.__repr__())
return None, self @ SlowPartialConstantOperator(self.domain, c_inp.keys())
return None, self @ InsertionOperator(self.domain, c_inp)
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
......@@ -375,7 +387,6 @@ class _OpChain(_CombinedOperator):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
......@@ -416,7 +427,6 @@ class _OpProd(Operator):
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
......@@ -462,7 +472,6 @@ class _OpSum(Operator):
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
......
......@@ -15,11 +15,10 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from .diagonal_operator import DiagonalOperator
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .sampling_enabler import SamplingDtypeSetter
from .scaling_operator import ScalingOperator
......@@ -56,11 +55,15 @@ class SandwichOperator(EndomorphicOperator):
old_cheese = cheese
cheese = old_cheese._cheese
bun = old_cheese._bun @ bun
if not isinstance(bun, LinearOperator):
raise TypeError("bun must be a linear operator")
if isinstance(bun, ScalingOperator):
return cheese.scale(abs(bun._factor)**2)
if cheese is not None and not isinstance(cheese, LinearOperator):
raise TypeError("cheese must be a linear operator or None")
if cheese is None:
# FIXME Sampling dtype not clear in this case
cheese = ScalingOperator(bun.target, 1.)
op = bun.adjoint(bun)
else:
......
......@@ -334,9 +334,9 @@ class NullOperator(LinearOperator):
@staticmethod
def _nullfield(dom):
if isinstance(dom, DomainTuple):
return Field(dom, 0)
return Field(dom, 0.)
else:
return MultiField.full(dom, 0)
return MultiField.full(dom, 0.)
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -60,10 +60,10 @@ class ConstCollector(object):
class ConstantOperator(Operator):
def __init__(self, dom, output):
def __init__(self, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
self._target = output.domain
self._domain = makeDomain({})
self._target = makeDomain(output.domain)
self._output = output
def apply(self, x):
......@@ -74,42 +74,17 @@ class ConstantOperator(Operator):
return self._output
def __repr__(self):
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- ConstantOperator <- {dom}'
class SlowPartialConstantOperator(Operator):
def __init__(self, domain, constant_keys):
from ..sugar import makeDomain
if not isinstance(domain, MultiDomain):
raise TypeError
if set(constant_keys) > set(domain.keys()) or len(constant_keys) == 0:
raise ValueError
self._keys = set(constant_keys) & set(domain.keys())
self._domain = self._target = makeDomain(domain)
def apply(self, x):
self._check_input(x)
if x.jac is None:
return x
jac = {kk: ScalingOperator(dd, 0 if kk in self._keys else 1)
for kk, dd in self._domain.items()}
return x.prepend_jac(BlockDiagonalOperator(x.jac.domain, jac))
def __repr__(self):
return f'SlowPartialConstantOperator ({self._keys})'
return f'{tgt} <- ConstantOperator'
class ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
def __init__(self, output):
from ..sugar import makeDomain
from ..field import Field
self._domain = makeDomain(dom)
self._domain = makeDomain({})
if not isinstance(output, Field):
output = Field.scalar(float(output))
if self.target is not output.domain:
raise TypeError
self._output = output
def apply(self, x):
......@@ -117,9 +92,38 @@ class ConstantEnergyOperator(EnergyOperator):
if x.jac is not None:
val = self._output
jac = NullOperator(self._domain, self._target)
# FIXME Do we need a metric here?
met = NullOperator(self._domain, self._domain) if x.want_metric else None
return x.new(val, jac, met)
return self._output
class InsertionOperator(Operator):
def __init__(self, target, cst_field):
from ..multi_field import MultiField
from ..sugar import makeDomain
if not isinstance(target, MultiDomain):
raise TypeError
if not isinstance(cst_field, MultiField):
raise TypeError
self._target = MultiDomain.make(target)
cstdom = cst_field.domain
vardom = makeDomain({kk: vv for kk, vv in self._target.items()
if kk not in cst_field.keys()})
self._domain = vardom
self._cst = cst_field
jac = {kk: ScalingOperator(vv, 1.) for kk, vv in self._domain.items()}
self._jac = BlockDiagonalOperator(self._domain, jac) + NullOperator(makeDomain({}), cstdom)
def apply(self, x):
assert len(set(self._cst.keys()) & set(x.domain.keys())) == 0
val = x if x.jac is None else x.val
val = val.unite(self._cst)
if x.jac is None:
return val
return x.new(val, self._jac)
def __repr__(self):
return 'ConstantEnergyOperator <- {}'.format(self.domain.keys())
from ..utilities import indent
subs = f'Constant: {self._cst.keys()}\nVariable: {self._domain.keys()}'
return 'InsertionOperator\n'+indent(subs)
......@@ -3,7 +3,7 @@
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__ = '6.0.0'
__version__ = '7.0'
def gitversion():
......
......@@ -85,41 +85,4 @@ def testgaussianenergy_compatibility(cplx):
loc0 = ift.MultiField.from_dict({'resi': resi})
loc1 = ift.MultiField.from_dict({'icov': ift.from_random(dom).exp()})
loc = loc0.unite(loc1)
val0 = e(loc).val
_, e0 = e.simplify_for_constant_input(loc0)
val1 = e0(loc).val
val2 = e0(loc.unite(loc0)).val
np.testing.assert_equal(val1, val2)
np.testing.assert_equal(val0, val1)
_, e1 = e.simplify_for_constant_input(loc1)
val1 = e1(loc).val
val2 = e1(loc.unite(loc1)).val
np.testing.assert_equal(val0, val1)
np.testing.assert_equal(val1, val2)
ift.extra.check_operator(e, loc, ntries=ntries)
ift.extra.check_operator(e0, loc, ntries=ntries, tol=1e-7)
ift.extra.check_operator(e1, loc, ntries=ntries)
# Test jacobian is zero
lin = ift.Linearization.make_var(loc, want_metric=True)
grad = e(lin).gradient.val
grad0 = e0(lin).gradient.val
grad1 = e1(lin).gradient.val
samp = e(lin).metric.draw_sample().val
samp0 = e0(lin).metric.draw_sample().val
samp1 = e1(lin).metric.draw_sample().val
np.testing.assert_equal(samp0['resi'], 0.)
np.testing.assert_equal(samp1['icov'], 0.)
np.testing.assert_equal(grad0['resi'], 0.)
np.testing.assert_equal(grad1['icov'], 0.)
np.testing.assert_(all(samp['resi'] != 0))
np.testing.assert_(all(samp['icov'] != 0))
np.testing.assert_(all(samp0['icov'] != 0))
np.testing.assert_(all(samp1['resi'] != 0))
np.testing.assert_(all(grad['resi'] != 0))
np.testing.assert_(all(grad['icov'] != 0))
np.testing.assert_(all(grad0['icov'] != 0))
np.testing.assert_(all(grad1['resi'] != 0))
ift.extra.check_operator(e, loc, ntries=20)
......@@ -11,10 +11,11 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import pytest
from numpy.testing import assert_, assert_allclose, assert_raises
......@@ -36,7 +37,6 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
op = ift.HarmonicSmoothingOperator(dom, 3)
if mf:
op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
import numpy as np
lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op
ic = ift.GradientNormController(iteration_limit=5)
ic.enable_logging()
......@@ -66,9 +66,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
locsamp = kl._local_samples
if isinstance(mean0, ift.MultiField):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
tmpmean = mean0.extract(tmph.domain)
else:
tmph = h
klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None, locsamp, False, True)
tmpmean = mean0
klpure = ift.MetricGaussianKL(tmpmean, tmph, nsamps, mirror_samples, None, locsamp, False, True)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......@@ -82,25 +84,10 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
ift.extra.assert_allclose(kl.gradient, klpure.gradient, 0, 1e-14)
return
for kk in h.domain.keys():
res0 = klpure.gradient[kk].val
if kk in constants:
res0 = 0*res0
for kk in kl.position.domain.keys():
res1 = kl.gradient[kk].val
if kk in constants:
res0 = 0*res1
else:
res0 = klpure.gradient[kk].val
assert_allclose(res0, res1)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
for ss in kl.samples:
ss = ss[kk].val
assert_allclose(ss, 0*ss)
# Test constants (after some minimization)
cg = ift.GradientNormController(iteration_limit=5)
minimizer = ift.NewtonCG(cg, enable_logging=True)
kl, _ = minimizer(kl)
if len(constants) != 2:
assert_(len(minimizer.inversion_history) > 0)