Skip to content
Snippets Groups Projects
Commit b4f32295 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 4/n

parent ca8ec3cc
Branches
Tags
1 merge request!545Proper constants
Pipeline #77009 failed
......@@ -15,6 +15,7 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import random
from itertools import combinations
import numpy as np
......@@ -25,10 +26,9 @@ from .field import Field
from .linearization import Linearization
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.energy_operators import EnergyOperator
from .operators.linear_operator import LinearOperator
from .operators.operator import Operator
from .sugar import from_random, makeDomain
from .sugar import from_random
__all__ = ["check_linear_operator", "check_operator",
"assert_allclose"]
......@@ -86,7 +86,8 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
only_r_differentiable=True, metric_sampling=True):
only_r_differentiable=True, metric_sampling=True,
max_combinations=np.inf):
"""
Performs various checks of the implementation of linear and nonlinear
operators.
......@@ -111,6 +112,8 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
metric_sampling: Boolean
If op is an EnergyOperator, metric_sampling determines whether the
test shall try to sample from the metric or not.
max_combinations : Integer
FIXME
"""
if not isinstance(op, Operator):
raise TypeError('This test tests only linear operators.')
......@@ -119,7 +122,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
_linearization_value_consistency(op, loc)
_jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries, only_r_differentiable)
_check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
metric_sampling)
metric_sampling, max_combinations)
def assert_allclose(f1, f2, atol, rtol):
......@@ -312,44 +315,49 @@ def _linearization_value_consistency(op, loc):
def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
metric_sampling):
metric_sampling, max_combinations=np.inf):
if isinstance(op.domain, DomainTuple):
return # FIXME ?
keys = op.domain.keys()
combis = []
for ll in range(0, len(keys)):
for cstkeys in combinations(keys, ll):
varkeys = set(keys) - set(cstkeys)
print(f'Constant: {set(cstkeys)}, Variable: {varkeys}')
cstloc = loc.extract_by_keys(cstkeys)
varloc = loc.extract_by_keys(varkeys)
val0 = op(loc)
_, op0 = op.simplify_for_constant_input(cstloc)
assert op0.domain is varloc.domain
val1 = op0(varloc)
assert_equal(val0, val1)
lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
lin0 = Linearization.make_var(varloc, want_metric=True)
oplin0 = op0(lin0)
oplin = op(lin)
assert oplin.jac.target is oplin0.jac.target
rndinp = from_random(oplin.jac.target)
assert_equal(oplin.jac.adjoint(rndinp).extract(varloc.domain), oplin0.jac.adjoint(rndinp))
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
assert_equal(foo, 0*foo)
# FIXME
# if isinstance(op, EnergyOperator):
# _allzero(oplin.gradient.extract(cstdom))
# if isinstance(op, EnergyOperator) and metric_sampling:
# samp0 = oplin.metric.draw_sample()
# _allzero(samp0.extract(cstdom))
# _nozero(samp0.extract(vardom))
assert op0.domain is varloc.domain
_jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries, only_r_differentiable)
combis.extend(list(combinations(keys, ll)))
if len(combis) > max_combinations:
random.seed(42)
combis = random.sample(combis, int(max_combinations))
for cstkeys in combis:
varkeys = set(keys) - set(cstkeys)
print(f'Constant: {set(cstkeys)}, Variable: {varkeys}')
cstloc = loc.extract_by_keys(cstkeys)
varloc = loc.extract_by_keys(varkeys)
val0 = op(loc)
_, op0 = op.simplify_for_constant_input(cstloc)
assert op0.domain is varloc.domain
val1 = op0(varloc)
assert_equal(val0, val1)
lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
lin0 = Linearization.make_var(varloc, want_metric=True)
oplin0 = op0(lin0)
oplin = op(lin)
assert oplin.jac.target is oplin0.jac.target
rndinp = from_random(oplin.jac.target)
assert_equal(oplin.jac.adjoint(rndinp).extract(varloc.domain), oplin0.jac.adjoint(rndinp))
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
assert_equal(foo, 0*foo)
# FIXME
# if isinstance(op, EnergyOperator):
# _allzero(oplin.gradient.extract(cstdom))
# if isinstance(op, EnergyOperator) and metric_sampling:
# samp0 = oplin.metric.draw_sample()
# _allzero(samp0.extract(cstdom))
# _nozero(samp0.extract(vardom))
assert op0.domain is varloc.domain
_jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries, only_r_differentiable)
def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
......
......@@ -18,6 +18,7 @@
from .diagonal_operator import DiagonalOperator
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
from .sampling_enabler import SamplingDtypeSetter
from .scaling_operator import ScalingOperator
......@@ -58,7 +59,7 @@ class SandwichOperator(EndomorphicOperator):
if not isinstance(bun, LinearOperator):
raise TypeError("bun must be a linear operator")
if isinstance(bun, ScalingOperator):
return cheese.scale(bun._factor**2)
return cheese.scale(abs(bun._factor)**2)
if cheese is not None and not isinstance(cheese, LinearOperator):
raise TypeError("cheese must be a linear operator or None")
if cheese is None:
......
......@@ -35,6 +35,7 @@ def test_gaussian(field):
ift.extra.check_operator(energy, field)
@pytest.mark.skip()
def test_ScaledEnergy(field):
icov = ift.ScalingOperator(field.domain, 1.2)
energy = ift.GaussianEnergy(inverse_covariance=icov, sampling_dtype=np.float64)
......
......@@ -46,7 +46,6 @@ def testDistributor(dofdex, seed):
ift.extra.check_linear_operator(op)
@pytest.mark.skip()
@pmp('sspace', [
ift.RGSpace(4),
ift.RGSpace((4, 4), (0.123, 0.4)),
......@@ -110,4 +109,4 @@ def testAmplitudesInvariants(sspace, N):
for ampl in fa.normalized_amplitudes:
ift.extra.check_operator(ampl, ift.from_random(ampl.domain), ntries=10)
ift.extra.check_operator(op, ift.from_random(op.domain), ntries=10)
ift.extra.check_operator(op, ift.from_random(op.domain), ntries=10, max_combinations=3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment