Commit 3f47aec4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'generalize_minisanity' into 'NIFTy_7'

Generalize minisanity

See merge request !588
parents 78f2127e d170f03a
Pipeline #88403 passed with stages
in 26 minutes and 43 seconds
......@@ -126,7 +126,8 @@ def main():
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
ift.extra.minisanity(KL, data, sig.inverse, signal_response)
ift.extra.minisanity(data, lambda x: N.inverse, signal_response,
KL.position, KL.samples)
# Plot current reconstruction
plot = ift.Plot()
......@@ -185,7 +185,7 @@ class DomainTuple(object):
return self._dom.__hash__()
def __eq__(self, x):
return (self is x) or (self._dom == x._dom)
return (self is x) or (isinstance(x, DomainTuple) and self._dom == x._dom)
def __ne__(self, x):
return not self.__eq__(x)
......@@ -22,10 +22,10 @@ import numpy as np
from .domain_tuple import DomainTuple
from .field import Field
from .linearization import Linearization
from import Energy
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.adder import Adder
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.energy_operators import EnergyOperator
from .operators.linear_operator import LinearOperator
from .operators.operator import Operator
......@@ -85,6 +85,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
_full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
rtol, only_r_linear)
_check_sqrt(op, domain_dtype)
_check_sqrt(op.adjoint, target_dtype)
_check_sqrt(op.inverse, target_dtype)
_check_sqrt(op.adjoint.inverse, domain_dtype)
def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
......@@ -197,6 +201,23 @@ def _domain_check_linear(op, domain_dtype=None, inp=None):
myassert(op(inp).domain is
def _check_sqrt(op, domain_dtype):
if not isinstance(op, EndomorphicOperator):
raise RuntimeError("Operator implements get_sqrt() although it is not an endomorphic operator.")
except AttributeError:
sqop = op.get_sqrt()
except (NotImplementedError, ValueError):
fld = from_random(op.domain, dtype=domain_dtype)
a = op(fld)
b = (sqop.adjoint @ sqop)(fld)
return assert_allclose(a, b, rtol=1e-15)
def _domain_check_nonlinear(op, loc):
myassert(isinstance(loc, (Field, MultiField)))
......@@ -374,7 +395,7 @@ def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
atol=tol**2, rtol=tol**2)
def minisanity(energy, data, sqrtmetric, modeldata_operator):
def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
"""Log information about the current fit quality and prior compatibility.
Log a table with fitting information for the likelihood and the prior.
......@@ -395,36 +416,65 @@ def minisanity(energy, data, sqrtmetric, modeldata_operator):
energy : Energy
Energy object which contains current mean and potentially samples.
data : Field or MultiField
Data which is subtracted from the output of `model_data`.
sqrtmetric : LinearOperator
Linear operator which applies the inverse of the square root of the
noise covariance.
metric_at_pos : function
Function which takes a `Field` or `MultiField` in the domain of `mean`
and returns an endomorphic operator which applies the inverse of the
noise covariance in the domain of `data`.
model_data : Operator
Operator which generates
Operator which generates model data.
mean : Field or MultiField
Mean of input of `model_data`.
samples : iterable of Field or MultiField, optional
Residual samples around `mean`. Default: no samples.
For computing the reduced chi^2 values and the normalized residuals, the
metric at `mean` is used.
from .logger import logger
if not (
isinstance(energy, Energy)
and isinstance(sqrtmetric, LinearOperator)
and is_operator(modeldata_operator)
and is_fieldlike(data)
and is_fieldlike(mean)
raise TypeError
normresi = sqrtmetric @ Adder(data, neg=True) @ modeldata_operator
keylen = 18
for dom in [, energy.position.domain]:
for dom in [data.domain, mean.domain]:
if isinstance(dom, MultiDomain):
keylen = max([max(map(len, dom.keys())), keylen])
keylen = min([keylen, 42])
s0 = _comp_chisq(normresi, energy, keylen)
s1 = _comp_chisq(ScalingOperator(energy.position.domain, 1), energy, keylen)
from .logger import logger
op0 = metric_at_pos(mean).get_sqrt() @ Adder(data, neg=True) @ modeldata_operator
op1 = ScalingOperator(mean.domain, 1)
if not isinstance(, MultiDomain):
op0 = op0.ducktape_left("<None>")
if not isinstance(, MultiDomain):
op1 = op1.ducktape_left("<None>")
s = [full(mean.domain, 0.0)] if samples is None else samples
xop = op0, op1
xkeys =,
xredchisq, xscmean, xndof = 2*[None], 2*[None], 2*[None]
for aa in [0, 1]:
xredchisq[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
xscmean[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
xndof[aa] = {}
for ii, ss in enumerate(s):
for aa in [0, 1]:
rr = xop[aa].force(mean.unite(ss))
for kk in xkeys[aa]:
xredchisq[aa][kk].add(np.nansum(abs(rr[kk].val) ** 2) / rr[kk].size)
xndof[aa][kk] = rr[kk].size - np.sum(np.isnan(rr[kk].val))
s0 = _tableentries(xredchisq[0], xscmean[0], xndof[0], keylen)
s1 = _tableentries(xredchisq[1], xscmean[1], xndof[1], keylen)
f =
n = 38 + keylen
......@@ -448,23 +498,7 @@ class _bcolors:
BOLD = "\033[1m"
def _comp_chisq(op, energy, keylen):
p = energy.position
hass = hasattr(energy, "samples")
s = energy.samples if hass else [full(energy.domain, 0.0)]
mf = isinstance(, MultiDomain)
if not mf:
op = op.ducktape_left("<None>")
keys =
redchisq = {kk: StatCalculator() for kk in keys}
mean = {kk: StatCalculator() for kk in keys}
ndof = {}
for ii, ss in enumerate(s):
rr = op.force(p.unite(ss))
for kk in keys:
redchisq[kk].add(np.nansum(abs(rr[kk].val) ** 2) / rr[kk].size)
ndof[kk] = rr[kk].size - np.sum(np.isnan(rr[kk].val))
def _tableentries(redchisq, scmean, ndof, keylen):
out = ""
for kk in redchisq.keys():
if len(kk) > keylen:
......@@ -483,9 +517,9 @@ def _comp_chisq(op, energy, keylen):
out += f"{foo:>11}"
foo = f"{mean[kk].mean:.1f}"
foo = f"{scmean[kk].mean:.1f}"
foo += f" ± {np.sqrt(mean[kk].var):.1f}"
foo += f" ± {np.sqrt(scmean[kk].var):.1f}"
except RuntimeError:
out += f"{foo:>14}"
......@@ -103,7 +103,7 @@ class MultiDomain(object):
def __eq__(self, x):
if self is x:
return True
return list(self.items()) == list(x.items())
return isinstance(x, MultiDomain) and list(self.items()) == list(x.items())
def __ne__(self, x):
return not self.__eq__(x)
......@@ -36,7 +36,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
if not isinstance(domain, MultiDomain):
raise TypeError("MultiDomain expected")
self._domain = domain
self._ops = tuple(operators[key] for key in domain.keys())
self._ops = tuple(operators[key] if key in operators else None for key in domain.keys())
self._capability = self._all_ops
for op in self._ops:
if op is not None:
......@@ -47,6 +47,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
raise TypeError("LinearOperator expected")
def get_sqrt(self):
ops = {}
for ii, kk in enumerate(self._domain.keys()):
if self._ops[ii] is None:
ops[kk] = self._ops[ii].get_sqrt()
return BlockDiagonalOperator(self._domain, ops)
def apply(self, x, mode):
self._check_input(x, mode)
val = tuple(op.apply(v, mode=mode) if op is not None else v
......@@ -166,5 +166,10 @@ class DiagonalOperator(EndomorphicOperator):
res = Field.from_random(domain=self._domain, random_type="normal", dtype=dtype)
return self.process_sample(res, from_inverse)
def get_sqrt(self):
if not np.iscomplexobj(self._ldiag) or (self._ldiag < 0).any():
raise ValueError("get_sqrt() works only for positive definite operators.")
return self._from_ldiag((), np.sqrt(self._ldiag))
def __repr__(self):
return "DiagonalOperator"
......@@ -75,6 +75,20 @@ class EndomorphicOperator(LinearOperator):
raise NotImplementedError
def get_sqrt(self):
"""Return operator op which obeys `self == op.adjoint @ op`.
Note that this function is only implemented for operators with real
Operator which is the square root of `self`
raise NotImplementedError
def _dom(self, mode):
return self._domain
......@@ -93,6 +93,11 @@ class SandwichOperator(EndomorphicOperator):
return self._bun.adjoint_times(
def get_sqrt(self):
if self._cheese is None:
return self._bun
return self._cheese.get_sqrt() @ self._bun
def __repr__(self):
from ..utilities import indent
return "\n".join((
......@@ -95,6 +95,12 @@ class ScalingOperator(EndomorphicOperator):
from ..sugar import from_random
return from_random(domain=self._domain, random_type="normal", dtype=dtype, std=self._get_fct(from_inverse))
def get_sqrt(self):
fct = self._get_fct(False)
if np.iscomplexobj(fct) or fct < 0:
raise ValueError("get_sqrt() works only for positive definite operators.")
return ScalingOperator(self._domain, fct)
def __call__(self, other):
res = EndomorphicOperator.__call__(self, other)
if np.isreal(self._factor) and self._factor >= 0:
......@@ -71,3 +71,13 @@ def test_blockdiagonal():
f1 = op2(ift.full(dom, 1))
for val in f1.values():
assert_equal((val == 40).s_all(), True)
def test_blockdiagonal_nontrivial():
dom = ift.makeDomain({"d1": ift.RGSpace(10), "d2": ift.UnstructuredDomain(2)})
op = ift.BlockDiagonalOperator(dom, {"d1": ift.ScalingOperator(dom["d1"], 2)})
assert op.domain == dom
fld = ift.from_random(dom)
ift.extra.assert_equal(op(fld)["d1"], 2*fld["d1"])
ift.extra.assert_equal(op(fld)["d2"], fld["d2"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment