Commit 9dea1d88 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 5/n

parent b4f32295
......@@ -26,6 +26,7 @@ from .field import Field
from .linearization import Linearization
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.energy_operators import EnergyOperator
from .operators.linear_operator import LinearOperator
from .operators.operator import Operator
from .sugar import from_random
......@@ -320,14 +321,13 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
return # FIXME ?
keys = op.domain.keys()
combis = []
for ll in range(0, len(keys)):
for ll in range(1, len(keys)):
combis.extend(list(combinations(keys, ll)))
if len(combis) > max_combinations:
random.seed(42)
combis = random.sample(combis, int(max_combinations))
for cstkeys in combis:
varkeys = set(keys) - set(cstkeys)
print(f'Constant: {set(cstkeys)}, Variable: {varkeys}')
cstloc = loc.extract_by_keys(cstkeys)
varloc = loc.extract_by_keys(varkeys)
......@@ -348,13 +348,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
assert_equal(foo, 0*foo)
# FIXME
# if isinstance(op, EnergyOperator):
# _allzero(oplin.gradient.extract(cstdom))
# if isinstance(op, EnergyOperator) and metric_sampling:
# samp0 = oplin.metric.draw_sample()
# _allzero(samp0.extract(cstdom))
# _nozero(samp0.extract(vardom))
if isinstance(op, EnergyOperator) and metric_sampling:
oplin.metric.draw_sample()
assert op0.domain is varloc.domain
_jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries, only_r_differentiable)
......
......@@ -302,11 +302,7 @@ class Operator(metaclass=NiftyMeta):
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import InsertionOperator
s = ('SlowPartialConstantOperator used. You might want to consider'
' implementing `_simplify_for_constant_input_nontrivial()` for'
' this operator:')
logger.warning(s)
logger.warning(self.__repr__())
logger.warning('SlowPartialConstantOperator used.')
return None, self @ InsertionOperator(self.domain, c_inp)
def ptw(self, op, *args, **kwargs):
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......
......@@ -69,7 +69,7 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
kl1 = ift.MetricGaussianKL(mean0.extract(tmph.domain), tmph, 2, mirror_samples, comm, locsamp, False, True)
elif mode == 1:
kl0 = ift.MetricGaussianKL.make(**args)
samples = kl0._local_samples
......@@ -80,7 +80,7 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
_, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
else:
tmph = h
kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
kl1 = ift.MetricGaussianKL(mean0.extract(tmph.domain), tmph, 2, mirror_samples, comm, locsamp, False, True)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
......@@ -92,31 +92,9 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
# Test gradient
if mf:
for kk in h.domain.keys():
for kk in kl0.gradient.domain.keys():
res0 = kl0.gradient[kk].val
if kk in constants:
res0 = 0*res0
res1 = kl1.gradient[kk].val
assert_equal(res0, res1)
else:
assert_equal(kl0.gradient.val, kl1.gradient.val)
# Test point_estimates (after drawing samples)
if mf:
for kk in point_estimates:
for ss in kl0.samples:
ss = ss[kk].val
assert_equal(ss, 0*ss)
for ss in kl1.samples:
ss = ss[kk].val
assert_equal(ss, 0*ss)
# Test constants (after some minimization)
if mf:
cg = ift.GradientNormController(iteration_limit=5)
minimizer = ift.NewtonCG(cg)
for e in [kl0, kl1]:
e, _ = minimizer(e)
diff = (mean0 - e.position).to_dict()
for kk in constants:
assert_equal(diff[kk].val, 0*diff[kk].val)
......@@ -108,5 +108,7 @@ def testAmplitudesInvariants(sspace, N):
assert_(op.target[-1] == fsspace)
for ampl in fa.normalized_amplitudes:
ift.extra.check_operator(ampl, ift.from_random(ampl.domain), ntries=10)
ift.extra.check_operator(op, ift.from_random(op.domain), ntries=10, max_combinations=3)
ift.extra.check_operator(ampl, ift.from_random(ampl.domain),
ntries=3, max_combinations=5)
ift.extra.check_operator(op, ift.from_random(op.domain),
ntries=3, max_combinations=5)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment