Commit 254dfe98 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'constantsupport' into 'NIFTy_7'

Constantsupport

See merge request !529
parents 46f2d697 8205d1a9
Pipeline #76161 passed with stages
in 13 minutes and 35 seconds
......@@ -37,6 +37,13 @@ def assert_allclose(f1, f2, atol, rtol):
assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def assert_equal(f1, f2):
if isinstance(f1, Field):
return np.testing.assert_equal(f1.val, f2.val)
for key, val in f1.items():
assert_equal(val, f2[key])
def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear):
needed_cap = op.TIMES | op.ADJOINT_TIMES
......@@ -249,7 +256,7 @@ def _linearization_value_consistency(op, loc):
def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True,
only_r_differentiable=True):
only_r_differentiable=True):
"""
Checks the Jacobian of an operator against its finite difference
approximation.
......
......@@ -46,7 +46,6 @@ class BlockDiagonalOperator(EndomorphicOperator):
else:
raise TypeError("LinearOperator expected")
def apply(self, x, mode):
self._check_input(x, mode)
val = tuple(op.apply(v, mode=mode) if op is not None else v
......
......@@ -48,6 +48,10 @@ def _check_sampling_dtype(domain, dtypes):
raise TypeError
def _iscomplex(dtype):
return np.issubdtype(dtype, np.complexfloating)
def _field_to_dtype(field):
if isinstance(field, Field):
dt = field.dtype
......@@ -127,10 +131,10 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
The covariance is assumed to be diagonal.
.. math ::
E(s,D) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),
E(s,D) = - \\log G(s, C) = 0.5 (s)^\\dagger C (s) - 0.5 tr log(C),
an information energy for a Gaussian distribution with residual s and
diagonal covariance D.
inverse diagonal covariance C.
The domain of this energy will be a MultiDomain with two keys,
the target will be the scalar domain.
......@@ -139,10 +143,10 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
domain : Domain, DomainTuple, tuple of Domain
domain of the residual and domain of the covariance diagonal.
residual : key
residual_key : key
Residual key of the Gaussian.
inverse_covariance : key
inverse_covariance_key : key
Inverse covariance diagonal key of the Gaussian.
sampling_dtype : np.dtype
......@@ -156,7 +160,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, self._dt)
self._cplx = np.issubdtype(sampling_dtype, np.complexfloating)
self._cplx = _iscomplex(sampling_dtype)
def apply(self, x):
self._check_input(x)
......@@ -171,6 +175,47 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantEnergyOperator
assert len(c_inp.keys()) == 1
key = c_inp.keys()[0]
assert key in self._domain.keys()
cst = c_inp[key]
if key == self._kr:
res = _SpecialGammaEnergy(cst).ducktape(self._ki)
else:
dt = self._dt[self._kr]
res = GaussianEnergy(inverse_covariance=makeOp(cst),
sampling_dtype=dt).ducktape(self._kr)
trlog = cst.log().sum().val_rw()
if not _iscomplex(dt):
trlog /= 2
res = res + ConstantEnergyOperator(res.domain, -trlog)
res = res + ConstantEnergyOperator(self._domain, 0.)
assert res.domain is self.domain
assert res.target is self.target
return None, res
class _SpecialGammaEnergy(EnergyOperator):
def __init__(self, residual):
self._domain = DomainTuple.make(residual.domain)
self._resi = residual
self._cplx = _iscomplex(self._resi.dtype)
self._scale = ScalingOperator(self._domain, 1 if self._cplx else .5)
def apply(self, x):
self._check_input(x)
r = self._resi
if self._cplx:
res = 0.5*(r*x.real).vdot(r).real - x.ptw("log").sum()
else:
res = 0.5*((r*x).vdot(r) - x.ptw("log").sum())
if not x.want_metric:
return res
met = makeOp((self._scale(x.val))**(-2))
return res.add_metric(SamplingDtypeSetter(met, self._resi.dtype))
class GaussianEnergy(EnergyOperator):
"""Computes a negative-log Gaussian.
......@@ -225,14 +270,13 @@ class GaussianEnergy(EnergyOperator):
if sampling_dtype != _field_to_dtype(self._mean):
raise ValueError("Sampling dtype and mean not compatible")
self._icov = inverse_covariance
if inverse_covariance is None:
self._op = Squared2NormOperator(self._domain).scale(0.5)
self._met = ScalingOperator(self._domain, 1)
self._trivial_invcov = True
else:
self._op = QuadraticFormOperator(inverse_covariance)
self._met = inverse_covariance
self._trivial_invcov = False
if sampling_dtype is not None:
self._met = SamplingDtypeSetter(self._met, sampling_dtype)
......
......@@ -276,6 +276,8 @@ class Operator(metaclass=NiftyMeta):
if c_inp is None:
return None, self
# Convention: If c_inp is MultiField, it needs to be defined on a
# subdomain of self._domain
if isinstance(self.domain, MultiDomain):
assert isinstance(c_inp.domain, MultiDomain)
if set(c_inp.keys()) > set(self.domain.keys()):
......
......@@ -102,10 +102,8 @@ class SliceOperator(LinearOperator):
return Field.from_raw(self.domain, res)
def __str__(self):
ss = (
f"{self.__class__.__name__}"
f"({self.domain.shape} -> {self.target.shape})"
)
ss = (f"{self.__class__.__name__}"
f"({self.domain.shape} -> {self.target.shape})")
return ss
......
......@@ -173,16 +173,9 @@ class FieldAdapter(LinearOperator):
return MultiField(self._tgt(mode), (x,))
def __repr__(self):
s = 'FieldAdapter'
dom = isinstance(self._domain, MultiDomain)
tgt = isinstance(self._target, MultiDomain)
if dom and tgt:
s += ' {} <- {}'.format(self._target.keys(), self._domain.keys())
elif dom:
s += ' <- {}'.format(self._domain.keys())
elif tgt:
s += ' {} <-'.format(self._target.keys())
return s
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- {dom}'
class _SlowFieldAdapter(LinearOperator):
......@@ -354,6 +347,12 @@ class NullOperator(LinearOperator):
tgt = self.target.keys() if isinstance(self.target, MultiDomain) else '()'
return f'{tgt} <- NullOperator <- {dom}'
def draw_sample(self, from_inverse=False):
if self._domain is not self._target:
raise RuntimeError
from ..sugar import full
return full(self._domain, 0.)
class PartialExtractor(LinearOperator):
def __init__(self, domain, target):
......@@ -378,3 +377,6 @@ class PartialExtractor(LinearOperator):
res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res1 = MultiField.full(self._compldomain, 0.)
return res0.unite(res1)
def __repr__(self):
return f'{self.target.keys()} <- {self.domain.keys()}'
......@@ -109,7 +109,10 @@ class SlowPartialConstantOperator(Operator):
class ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
from ..sugar import makeDomain
from ..field import Field
self._domain = makeDomain(dom)
if not isinstance(output, Field):
output = Field.scalar(float(output))
if self.target is not output.domain:
raise TypeError
self._output = output
......
......@@ -88,6 +88,15 @@ def test_variablecovariancegaussian(field):
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_specialgamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
energy = ift.operators.energy_operators._SpecialGammaEnergy(field)
loc = ift.from_random(energy.domain).exp()
ift.extra.check_jacobian_consistency(energy, loc, tol=1e-6, ntries=ntries)
energy(ift.Linearization.make_var(loc, want_metric=True)).metric.draw_sample()
def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
......
......@@ -28,6 +28,7 @@ def _flat_PS(k):
pmp = pytest.mark.parametrize
ntries = 10
@pmp('space', [ift.GLSpace(5),
......@@ -70,4 +71,55 @@ def test_gaussian_energy(space, nonlinearity, noise, seed):
energy = ift.GaussianEnergy(d, N) @ d_model()
ift.extra.check_jacobian_consistency(
energy, xi0, ntries=10, tol=1e-6)
energy, xi0, ntries=ntries, tol=1e-6)
@pmp('cplx', [True, False])
def testgaussianenergy_compatibility(cplx):
dt = np.complex128 if cplx else np.float64
dom = ift.UnstructuredDomain(3)
e = ift.VariableCovarianceGaussianEnergy(dom, 'resi', 'icov', dt)
resi = ift.from_random(dom)
if cplx:
resi = resi + 1j*ift.from_random(dom)
loc0 = ift.MultiField.from_dict({'resi': resi})
loc1 = ift.MultiField.from_dict({'icov': ift.from_random(dom).exp()})
loc = loc0.unite(loc1)
val0 = e(loc).val
_, e0 = e.simplify_for_constant_input(loc0)
val1 = e0(loc).val
val2 = e0(loc.unite(loc0)).val
np.testing.assert_equal(val1, val2)
np.testing.assert_equal(val0, val1)
_, e1 = e.simplify_for_constant_input(loc1)
val1 = e1(loc).val
val2 = e1(loc.unite(loc1)).val
np.testing.assert_equal(val0, val1)
np.testing.assert_equal(val1, val2)
ift.extra.check_jacobian_consistency(e, loc, ntries=ntries)
ift.extra.check_jacobian_consistency(e0, loc, ntries=ntries)
ift.extra.check_jacobian_consistency(e1, loc, ntries=ntries)
# Test jacobian is zero
lin = ift.Linearization.make_var(loc, want_metric=True)
grad = e(lin).gradient.val
grad0 = e0(lin).gradient.val
grad1 = e1(lin).gradient.val
samp = e(lin).metric.draw_sample().val
samp0 = e0(lin).metric.draw_sample().val
samp1 = e1(lin).metric.draw_sample().val
np.testing.assert_equal(samp0['resi'], 0.)
np.testing.assert_equal(samp1['icov'], 0.)
np.testing.assert_equal(grad0['resi'], 0.)
np.testing.assert_equal(grad1['icov'], 0.)
np.testing.assert_(all(samp['resi'] != 0))
np.testing.assert_(all(samp['icov'] != 0))
np.testing.assert_(all(samp0['icov'] != 0))
np.testing.assert_(all(samp1['resi'] != 0))
np.testing.assert_(all(grad['resi'] != 0))
np.testing.assert_(all(grad['icov'] != 0))
np.testing.assert_(all(grad0['icov'] != 0))
np.testing.assert_(all(grad1['resi'] != 0))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment