Commit 9cbf7497 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'minimization_sanity' into 'NIFTy_7'

Implement minisanity

See merge request !582
parents 0aa41a92 2b7da115
Pipeline #88062 passed with stages
in 35 minutes and 30 seconds
......@@ -89,6 +89,7 @@ def main():
# Specify noise
data_space = R.target
noise = .001
sig = ift.ScalingOperator(data_space, np.sqrt(noise))
N = ift.ScalingOperator(data_space, noise)
# Generate mock signal and data
......@@ -97,7 +98,7 @@ def main():
# Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController(
name='Sampling', deltaE=0.05, iteration_limit=100)
deltaE=0.05, iteration_limit=100)
ic_newton = ift.AbsDeltaEnergyController(
name='Newton', deltaE=0.5, iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton)
......@@ -125,6 +126,7 @@ def main():
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL, convergence = minimizer(KL)
mean = KL.position
ift.extra.minisanity(KL, data, sig.inverse, signal_response)
# Plot current reconstruction
plot = ift.Plot()
......
......@@ -22,22 +22,24 @@ import numpy as np
from .domain_tuple import DomainTuple
from .field import Field
from .linearization import Linearization
from .minimization.energy import Energy
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.adder import Adder
from .operators.energy_operators import EnergyOperator
from .operators.linear_operator import LinearOperator
from .operators.operator import Operator
from .sugar import from_random
from .operators.scaling_operator import ScalingOperator
from .probing import StatCalculator
from .sugar import from_random, full, is_fieldlike, is_operator
from .utilities import myassert
__all__ = ["check_linear_operator", "check_operator",
"assert_allclose"]
__all__ = ["check_linear_operator", "check_operator", "assert_allclose", "minisanity"]
def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=1e-12, rtol=1e-12, only_r_linear=False):
"""
Checks an operator for algebraic consistency of its capabilities.
"""Checks an operator for algebraic consistency of its capabilities.
Checks whether times(), adjoint_times(), inverse_times() and
adjoint_inverse_times() (if in capability list) is implemented
......@@ -87,8 +89,7 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
only_r_differentiable=True, metric_sampling=True):
"""
Performs various checks of the implementation of linear and nonlinear
"""Performs various checks of the implementation of linear and nonlinear
operators.
Computes the Jacobian with finite differences and compares it to the
......@@ -371,3 +372,123 @@ def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
target_dtype=dirder.dtype,
only_r_linear=only_r_differentiable,
atol=tol**2, rtol=tol**2)
def minisanity(energy, data, sqrtmetric, modeldata_operator):
"""Log information about the current fit quality and prior compatibility.
Log a table with fitting information for the likelihood and the prior.
Assume that the variables in `energy.position.domain` are standard-normal
distributed a priori. The table contains the reduced chi^2 value, the mean
and the number of degrees of freedom for every key of a `MultiDomain`. If
the domain is a `DomainTuple`, the displayed key is `<None>`.
If everything is consistent the reduced chi^2 values should be close to one
and the mean of the data residuals close to zero. If the reduced chi^2 value
in latent space is significantly bigger than one and only one degree of
freedom is present, the mean column gives an indication in which direction
to change the respective hyper parameters.
Ignore all NaN entries in the target of `modeldata_operator` and in `data`.
Print reduced chi-square values above 2 and 5 in orange and red,
respectively.
Parameters
----------
energy : Energy
Energy object which contains current mean and potentially samples.
data : Field or MultiField
Data which is subtracted from the output of `model_data`.
sqrtmetric : LinearOperator
Linear operator which applies the inverse of the square root of the
noise covariance.
model_data : Operator
Operator which generates
"""
if not (
isinstance(energy, Energy)
and isinstance(sqrtmetric, LinearOperator)
and is_operator(modeldata_operator)
and is_fieldlike(data)
):
raise TypeError
normresi = sqrtmetric @ Adder(data, neg=True) @ modeldata_operator
keylen = 18
for dom in [normresi.target, energy.position.domain]:
if isinstance(dom, MultiDomain):
keylen = max([max(map(len, dom.keys())), keylen])
keylen = min([keylen, 42])
s0 = _comp_chisq(normresi, energy, keylen)
s1 = _comp_chisq(ScalingOperator(energy.position.domain, 1), energy, keylen)
from .logger import logger
f = logger.info
n = 38 + keylen
f(n * "=")
f(
(keylen + 2) * " "
+ "{:>11}".format("reduced χ²")
+ "{:>14}".format("mean")
+ "{:>11}".format("# dof")
)
f(n * "-")
f("Data residuals\n" + s0)
f("Latent space\n" + s1)
f(n * "=")
class _bcolors:
WARNING = "\033[33m"
FAIL = "\033[31m"
ENDC = "\033[0m"
BOLD = "\033[1m"
def _comp_chisq(op, energy, keylen):
p = energy.position
hass = hasattr(energy, "samples")
s = energy.samples if hass else [full(energy.domain, 0.0)]
mf = isinstance(op.target, MultiDomain)
if not mf:
op = op.ducktape_left("<None>")
keys = op.target.keys()
redchisq = {kk: StatCalculator() for kk in keys}
mean = {kk: StatCalculator() for kk in keys}
ndof = {}
for ii, ss in enumerate(s):
rr = op(p + ss)
for kk in keys:
redchisq[kk].add(np.nansum(abs(rr[kk].val) ** 2) / rr[kk].size)
mean[kk].add(np.nanmean(rr[kk].val))
ndof[kk] = rr[kk].size - np.sum(np.isnan(rr[kk].val))
out = ""
for kk in redchisq.keys():
if len(kk) > keylen:
out += " " + kk[: keylen - 1] + "…"
else:
out += " " + kk.ljust(keylen)
foo = f"{redchisq[kk].mean:.1f}"
try:
foo += f" ± {np.sqrt(redchisq[kk].var):.1f}"
except RuntimeError:
pass
if redchisq[kk].mean > 5:
out += _bcolors.FAIL + _bcolors.BOLD + f"{foo:>11}" + _bcolors.ENDC
elif redchisq[kk].mean > 2:
out += _bcolors.WARNING + _bcolors.BOLD + f"{foo:>11}" + _bcolors.ENDC
else:
out += f"{foo:>11}"
foo = f"{mean[kk].mean:.1f}"
try:
foo += f" ± {np.sqrt(mean[kk].var):.1f}"
except RuntimeError:
pass
out += f"{foo:>14}"
out += f"{ndof[kk]:>11}"
out += "\n"
return out[:-1]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment