Commit 952f9dd9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'better_errors' into 'NIFTy_8'

Better errors, plotting and complex jax operators

See merge request !661
parents b4b9b12d 7d0d885f
Pipeline #105701 passed with stages
in 20 minutes and 59 seconds
Changes since NIFTy 7
=====================
Minisanity
----------
Terminal colors can be disabled in order to make the output of
`ift.extra.minisanity` more readable when written to a file.
Jax interface
-------------
......
......@@ -364,7 +364,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin = op(lin)
myassert(oplin.jac.target is oplin0.jac.target)
rndinp = from_random(oplin.jac.target)
rndinp = from_random(oplin.jac.target, dtype=oplin.val.dtype)
assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
......@@ -408,7 +408,8 @@ def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
atol=tol**2, rtol=tol**2)
def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None,
terminal_colors=True):
"""Log information about the current fit quality and prior compatibility.
Log a table with fitting information for the likelihood and the prior.
......@@ -446,6 +447,10 @@ def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
samples : iterable of Field or MultiField, optional
Residual samples around `mean`. Default: no samples.
terminal_colors : bool, optional
Setting this to false disables terminal colors. This may be useful if
the output of minisanity is written to a file. Default: True
Note
----
For computing the reduced chi^2 values and the normalized residuals, the
......@@ -459,6 +464,7 @@ def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
and is_fieldlike(mean)
):
raise TypeError
colors = bool(terminal_colors)
keylen = 18
for dom in [data.domain, mean.domain]:
if isinstance(dom, MultiDomain):
......@@ -486,8 +492,8 @@ def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
xscmean[aa][kk].add(np.nanmean(rr[kk].val))
xndof[aa][kk] = rr[kk].size - np.sum(np.isnan(rr[kk].val))
s0 = _tableentries(xredchisq[0], xscmean[0], xndof[0], keylen)
s1 = _tableentries(xredchisq[1], xscmean[1], xndof[1], keylen)
s0 = _tableentries(xredchisq[0], xscmean[0], xndof[0], keylen, colors)
s1 = _tableentries(xredchisq[1], xscmean[1], xndof[1], keylen, colors)
f = logger.info
n = 38 + keylen
......@@ -504,14 +510,14 @@ def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
f(n * "=")
class _bcolors:
WARNING = "\033[33m"
FAIL = "\033[31m"
ENDC = "\033[0m"
BOLD = "\033[1m"
def _tableentries(redchisq, scmean, ndof, keylen, colors):
class _bcolors:
WARNING = "\033[33m" if colors else ""
FAIL = "\033[31m" if colors else ""
ENDC = "\033[0m" if colors else ""
BOLD = "\033[1m" if colors else ""
def _tableentries(redchisq, scmean, ndof, keylen):
out = ""
for kk in redchisq.keys():
if len(kk) > keylen:
......
......@@ -56,7 +56,7 @@ class Field(Operator):
else:
raise TypeError("val must be of type numpy.ndarray")
if domain.shape != val.shape:
raise ValueError("shape mismatch between val and domain")
raise ValueError(f"shape mismatch between val and domain\n{domain.shape}\n{val.shape}")
self._domain = domain
self._val = val
self._val.flags.writeable = False
......@@ -310,8 +310,7 @@ class Field(Operator):
raise TypeError("The dot-partner must be an instance of " +
"the Field class")
if x._domain != self._domain:
raise ValueError("Domain mismatch")
utilities.check_domain_equality(x._domain, self._domain)
ndom = len(self._domain)
spaces = utilities.parse_spaces(spaces, ndom)
......@@ -339,8 +338,7 @@ class Field(Operator):
raise TypeError("The dot-partner must be an instance of " +
"the Field class")
if x._domain != self._domain:
raise ValueError("Domain mismatch")
utilities.check_domain_equality(x._domain, self._domain)
return vdot(self._val, x._val)
......@@ -671,13 +669,11 @@ class Field(Operator):
"\n- val = " + repr(self._val)
def extract(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(dom, self._domain)
return self
def extract_part(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(dom, self._domain)
return self
def unite(self, other):
......@@ -690,8 +686,7 @@ class Field(Operator):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
utilities.check_domain_equality(other._domain, self._domain)
return Field(self._domain, f(other._val))
if np.isscalar(other):
return Field(self._domain, f(other))
......
......@@ -19,6 +19,7 @@ import numpy as np
from .operators.operator import Operator
from .sugar import makeOp
from .utilities import check_domain_equality
class Linearization(Operator):
......@@ -41,8 +42,7 @@ class Linearization(Operator):
def __init__(self, val, jac, metric=None, want_metric=False):
self._val = val
self._jac = jac
if self._val.domain != self._jac.target:
raise ValueError("domain mismatch")
check_domain_equality(self._val.domain, self._jac.target)
self._want_metric = want_metric
self._metric = metric
......@@ -179,13 +179,10 @@ class Linearization(Operator):
return self
met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met)
from .sugar import makeOp
if other.jac is None:
if self.target != other.domain:
raise ValueError("domain mismatch")
check_domain_equality(self.target, other.domain)
return self.new(self._val*other, makeOp(other)(self._jac))
if self.target != other.target:
raise ValueError("domain mismatch")
check_domain_equality(self.target, other.target)
return self.new(
self.val*other.val,
(makeOp(other.val)(self.jac))._myadd(
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .domain_tuple import DomainTuple
from .utilities import frozendict, indent
from .utilities import check_domain_equality, frozendict, indent
class MultiDomain:
......@@ -128,8 +128,7 @@ class MultiDomain:
for dom in inp:
for key, subdom in zip(dom._keys, dom._domains):
if key in res:
if res[key] != subdom:
raise ValueError("domain mismatch")
check_domain_equality(res[key], subdom)
else:
res[key] = subdom
return MultiDomain.make(res)
......
......@@ -41,8 +41,7 @@ class MultiField(Operator):
raise ValueError("length mismatch")
for d, v in zip(domain._domains, val):
if isinstance(v, Field):
if v._domain != d:
raise ValueError("domain mismatch")
utilities.check_domain_equality(v._domain, d)
else:
raise TypeError("bad entry in val (must be Field)")
self._domain = domain
......@@ -137,13 +136,9 @@ class MultiField(Operator):
for kk in domain.keys()}
return MultiField.from_dict(dct)
def _check_domain(self, other):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
def s_vdot(self, x):
result = 0.
self._check_domain(x)
utilities.check_domain_equality(x._domain, self._domain)
for v1, v2 in zip(self._val, x._val):
result += v1.s_vdot(v2)
return result
......@@ -369,8 +364,7 @@ class MultiField(Operator):
def _binary_op(self, other, op):
f = getattr(Field, op)
if isinstance(other, MultiField):
if self._domain != other._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(self._domain, other._domain)
val = tuple(f(v1, v2)
for v1, v2 in zip(self._val, other._val))
else:
......
......@@ -17,7 +17,7 @@
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..utilities import indent
from ..utilities import check_domain_equality, indent
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -75,16 +75,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
return MultiField(self._domain, val)
def _combine_chain(self, op):
if self._domain != op._domain:
raise ValueError("domain mismatch")
check_domain_equality(self._domain, op._domain)
res = {key: v1(v2)
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
if self._domain != op._domain:
raise ValueError("domain mismatch")
check_domain_equality(self._domain, op._domain)
res = {key: SumOperator.make([v1, v2], [selfneg, opneg])
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
......
......@@ -44,8 +44,7 @@ class ChainOperator(LinearOperator):
def simplify(ops):
# verify domains
for i in range(len(ops) - 1):
if ops[i + 1].target != ops[i].domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(ops[i + 1].target, ops[i].domain)
# unpack ChainOperators
opsnew = []
for op in ops:
......
......@@ -69,8 +69,7 @@ def _ConvolutionOperator(domain, kernel, space=None):
lm = [d for d in domain]
lm[space] = lm[space].get_default_codomain()
lm = DomainTuple.make(lm)
if lm[space] != kernel.domain[0]:
raise ValueError("Input domain and kernel are incompatible")
utilities.check_domain_equality(lm[space], kernel.domain[0])
HT = HarmonicTransformOperator(lm, domain[space], space)
diag = DiagonalOperator(kernel*domain[space].total_volume, lm, (space,))
wgt = WeightApplier(domain, space, 1)
......
......@@ -61,15 +61,13 @@ class DiagonalOperator(EndomorphicOperator):
self._domain = DomainTuple.make(domain)
if spaces is None:
self._spaces = None
if diagonal.domain != self._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(diagonal.domain, self._domain)
else:
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) != len(diagonal.domain):
raise ValueError("spaces and domain must have the same length")
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
utilities.check_domain_equality(diagonal.domain[i], self._domain[j])
if self._spaces == tuple(range(len(self._domain))):
self._spaces = None # shortcut
......
......@@ -347,8 +347,7 @@ class GaussianEnergy(LikelihoodEnergyOperator):
if self._domain is None:
self._domain = newdom
else:
if self._domain != newdom:
raise ValueError("domain mismatch")
utilities.check_domain_equality(self._domain, newdom)
def apply(self, x):
self._check_input(x)
......
......@@ -61,13 +61,38 @@ class JaxOperator(Operator):
def apply(self, x):
from ..sugar import is_linearization, makeField
from ..multi_domain import MultiDomain
self._check_input(x)
if is_linearization(x):
res, bwd = self._vjp(x.val.val)
fwd = lambda y: self._fwd(x.val.val, y)
jac = _JaxJacobian(self._domain, self._target, fwd, bwd)
return x.new(makeField(self._target, _jax2np(res)), jac)
return makeField(self._target, _jax2np(self._func(x.val)))
res = _jax2np(self._func(x.val))
if isinstance(res, dict):
if not isinstance(self._target, MultiDomain):
raise TypeError(("Jax function returns a dictionary although the "
"target of the operator is a DomainTuple."))
if set(res.keys()) != set(self._target.keys()):
raise ValueError(("Keys do not match:\n"
f"Target keys: {self._target.keys()}\n"
f"Jax function returns: {res.keys()}"))
for kk in res.keys():
self._check_shape(self._target[kk].shape, res[kk].shape)
else:
if isinstance(self._target, MultiDomain):
raise TypeError(("Jax function does not return a dictionary "
"although the target of the operator is a "
"MultiDomain."))
self._check_shape(self._target.shape, res.shape)
return makeField(self._target, res)
@staticmethod
def _check_shape(shp_tgt, shp_jax):
if shp_tgt != shp_jax:
raise ValueError(("Output shapes do not match:\n"
f"Target shape is\t\t{shp_tgt}\n"
f"Jax function returns\t{shp_jax}"))
def _simplify_for_constant_input_nontrivial(self, c_inp):
func2 = lambda x: self._func({**x, **c_inp.val})
......@@ -140,12 +165,12 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
class _JaxJacobian(LinearOperator):
def __init__(self, domain, target, func, adjfunc):
def __init__(self, domain, target, func, func_transposed):
from ..sugar import makeDomain
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = func
self._adjfunc = adjfunc
self._func_transposed = func_transposed
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
......@@ -153,6 +178,6 @@ class _JaxJacobian(LinearOperator):
self._check_input(x, mode)
if mode == self.TIMES:
fx = self._func(x.val)
else:
fx = self._adjfunc(x.val)[0]
return makeField(self._tgt(mode), _jax2np(fx))
return makeField(self._tgt(mode), _jax2np(fx))
fx = self._func_transposed(x.conjugate().val)[0]
return makeField(self._tgt(mode), _jax2np(fx)).conjugate()
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .operator import Operator
from ..utilities import check_domain_equality
class LinearOperator(Operator):
......@@ -252,4 +253,4 @@ class LinearOperator(Operator):
def _check_input(self, x, mode):
self._check_mode(mode)
self._check_domain_equality(self._dom(mode), x.domain)
check_domain_equality(self._dom(mode), x.domain)
......@@ -20,7 +20,7 @@ import numpy as np
from .. import pointwise
from ..logger import logger
from ..multi_domain import MultiDomain
from ..utilities import NiftyMeta, indent, myassert
from ..utilities import NiftyMeta, check_domain_equality, indent, myassert
class Operator(metaclass=NiftyMeta):
......@@ -131,17 +131,6 @@ class Operator(metaclass=NiftyMeta):
"""
return None
@staticmethod
def _check_domain_equality(dom_op, dom_field):
if dom_op != dom_field:
s = "The operator's and field's domains don't match."
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
if not isinstance(dom_op, (DomainTuple, MultiDomain,)):
s += " Your operator's domain is neither a `DomainTuple`" \
" nor a `MultiDomain`."
raise ValueError(s)
def scale(self, factor):
if factor == 1:
return self
......@@ -275,7 +264,7 @@ class Operator(metaclass=NiftyMeta):
raise ValueError
if x.jac._factor != 1:
raise ValueError
self._check_domain_equality(self._domain, x.domain)
check_domain_equality(self._domain, x.domain)
def __call__(self, x):
if not isinstance(x, Operator):
......@@ -417,8 +406,7 @@ class _OpChain(_CombinedOperator):
self._domain = self._ops[-1].domain
self._target = self._ops[0].target
for i in range(1, len(self._ops)):
if self._ops[i-1].domain != self._ops[i].target:
raise ValueError("domain mismatch")
check_domain_equality(self._ops[i-1].domain, self._ops[i].target)
def apply(self, x):
self._check_input(x)
......
......@@ -174,11 +174,10 @@ class SplitOperator(LinearOperator):
elif isinstance(slc[i],
np.ndarray) and slc[i].dtype is np.dtype(bool):
if slc[i].size != d.size:
ve = (
raise ValueError(
"shape mismatch between desired slice {slc[i]}"
"and the shape of the domain {d.size}"
)
raise ValueError(ve)
k_tgt += [UnstructuredDomain(slc[i].sum())]
k_slc_by_ax += [slc[i]]
elif isinstance(slc[i], (tuple, list, np.ndarray)):
......
......@@ -22,6 +22,7 @@ from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..utilities import check_domain_equality
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -363,8 +364,7 @@ class PartialExtractor(LinearOperator):
self._domain = domain
self._target = target
for key in self._target.keys():
if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch")
check_domain_equality(self._domain[key], self._target[key])
self._capability = self.TIMES | self.ADJOINT_TIMES
self._compldomain = MultiDomain.make({kk: self._domain[kk]
for kk in self._domain.keys()
......
......@@ -26,6 +26,7 @@ from .domains.power_space import PowerSpace
from .domains.rg_space import RGSpace
from .field import Field
from .minimization.iteration_controllers import EnergyHistory
from .utilities import check_domain_equality, myassert
# relevant properties:
# - x/y size
......@@ -313,8 +314,7 @@ def _plot1D(f, ax, **kwargs):
if (len(dom) != 1):
raise ValueError("input field must have exactly one domain")
else:
if fld.domain != dom:
raise ValueError("domain mismatch")
check_domain_equality(fld.domain, dom)
dom = dom[0]
label = kwargs.pop("label", None)
......@@ -329,6 +329,10 @@ def _plot1D(f, ax, **kwargs):
if not isinstance(alpha, list):
alpha = [alpha] * len(f)
color = kwargs.pop("color", None)
if not isinstance(color, list):
color = [color] * len(f)
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
......@@ -341,7 +345,8 @@ def _plot1D(f, ax, **kwargs):
for i, fld in enumerate(f):
ycoord = fld.val
plt.plot(xcoord, ycoord, label=label[i],
linewidth=linewidth[i], alpha=alpha[i])
linewidth=linewidth[i], alpha=alpha[i],
color=color[i])
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
......@@ -354,7 +359,8 @@ def _plot1D(f, ax, **kwargs):
ycoord = fld.val_rw()
ycoord[0] = ycoord[1]
plt.plot(xcoord, ycoord, label=label[i],
linewidth=linewidth[i], alpha=alpha[i])
linewidth=linewidth[i], alpha=alpha[i],
color=color[i])
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
......@@ -572,7 +578,9 @@ class Plot:
nx = kwargs.pop("nx", 0)
ny = kwargs.pop("ny", 0)
if nx == ny == 0:
nx = ny = int(np.ceil(np.sqrt(nplot)))
ny = int(np.ceil(np.sqrt(nplot)))
nx = int(np.ceil(nplot/ny))
myassert(nx*ny >= nplot)
elif nx == 0:
nx = int(np.ceil(nplot/ny))
elif ny == 0:
......
......@@ -420,8 +420,7 @@ def makeOp(input, dom=None):
raise TypeError("need proper `dom` argument")
return ScalingOperator(dom, input)
if dom is not None:
if not dom == input.domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(dom, input.domain)
if input.domain is DomainTuple.scalar_domain():
return ScalingOperator(input.domain, input.val[()])
if isinstance(input, Field):
......@@ -442,8 +441,8 @@ def domain_union(domains):
- if MultiDomain, there must not be any conflicting components
"""
if isinstance(domains[0], DomainTuple):
if any(dom != domains[0] for dom in domains[1:]):
raise ValueError("domain mismatch")
for dom in domains[1:]:
utilities.check_domain_equality(dom, domains[0])
return domains[0]
return MultiDomain.union(domains)
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -25,7 +25,7 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
"value_reshaper", "lognormal_moments", "check_domain_equality"]
def my_sum(iterable):
......@@ -412,3 +412,19 @@ def myassert(val):
`__debug__` is False."""
if not val:
raise AssertionError
def check_domain_equality(domain0, domain1):
"""Check if two domains are equal and throw ValueError if not. Throw a
TypeError if one of the inputs is neither a DomainTuple nor a
MultiDomain.
"""
from .domain_tuple import DomainTuple
from .multi_domain import MultiDomain
from .domains.domain import Domain
for dom in [domain0, domain1]:
if not isinstance(dom, (MultiDomain, DomainTuple, Domain)):
raise TypeError("The following domain is neither an instance of "
f"ift.MultiDomain nor of ift.DomainTuple.\n{dom}")
if domain0 != domain1:
raise ValueError(f"Domain mismatch:\n{domain0}\n{domain1}")
......@@ -86,3 +86,59 @@ def test_jax_energy(dom):
continue
pos1 = ift.from_random(e.domain)
ift.extra.assert_allclose(e0(lin).metric(pos1), e(lin).metric(pos1))
def test_jax_errors():
dom = ift.UnstructuredDomain(2)
mdom = {"a": dom}
op = ift.JaxOperator(dom, dom, lambda x: {"a": x})
fld = ift.full(dom, 0.)
with pytest.raises(TypeError):
op(fld)
op = ift.JaxOperator(dom, mdom, lambda x: x)
with pytest.raises(TypeError):
op(fld)
op = ift.JaxOperator(dom, dom, lambda x: x[0])
with pytest.raises(ValueError):
op(fld)
op = ift.JaxOperator(dom, mdom, lambda x: {"a": x[0]})
with pytest.raises(ValueError):
op(fld)