Commit 2db6c449 authored by Philipp Arras's avatar Philipp Arras
Browse files

More informative errors for domain mismatches

parent b4b9b12d
......@@ -56,7 +56,7 @@ class Field(Operator):
else:
raise TypeError("val must be of type numpy.ndarray")
if domain.shape != val.shape:
raise ValueError("shape mismatch between val and domain")
raise ValueError(f"shape mismatch between val and domain\n{domain.shape}\n{val.shape}")
self._domain = domain
self._val = val
self._val.flags.writeable = False
......@@ -310,8 +310,7 @@ class Field(Operator):
raise TypeError("The dot-partner must be an instance of " +
"the Field class")
if x._domain != self._domain:
raise ValueError("Domain mismatch")
utilities.check_domain_equality(x._domain, self._domain)
ndom = len(self._domain)
spaces = utilities.parse_spaces(spaces, ndom)
......@@ -339,8 +338,7 @@ class Field(Operator):
raise TypeError("The dot-partner must be an instance of " +
"the Field class")
if x._domain != self._domain:
raise ValueError("Domain mismatch")
utilities.check_domain_equality(x._domain, self._domain)
return vdot(self._val, x._val)
......@@ -671,13 +669,11 @@ class Field(Operator):
"\n- val = " + repr(self._val)
def extract(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(dom, self._domain)
return self
def extract_part(self, dom):
if dom != self._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(dom, self._domain)
return self
def unite(self, other):
......@@ -690,8 +686,7 @@ class Field(Operator):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
utilities.check_domain_equality(other._domain, self._domain)
return Field(self._domain, f(other._val))
if np.isscalar(other):
return Field(self._domain, f(other))
......
......@@ -19,6 +19,7 @@ import numpy as np
from .operators.operator import Operator
from .sugar import makeOp
from .utilities import check_domain_equality
class Linearization(Operator):
......@@ -41,8 +42,7 @@ class Linearization(Operator):
def __init__(self, val, jac, metric=None, want_metric=False):
self._val = val
self._jac = jac
if self._val.domain != self._jac.target:
raise ValueError("domain mismatch")
check_domain_equality(self._val.domain, self._jac.target)
self._want_metric = want_metric
self._metric = metric
......@@ -179,13 +179,10 @@ class Linearization(Operator):
return self
met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met)
from .sugar import makeOp
if other.jac is None:
if self.target != other.domain:
raise ValueError("domain mismatch")
check_domain_equality(self.target, other.domain)
return self.new(self._val*other, makeOp(other)(self._jac))
if self.target != other.target:
raise ValueError("domain mismatch")
check_domain_equality(self.target, other.target)
return self.new(
self.val*other.val,
(makeOp(other.val)(self.jac))._myadd(
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .domain_tuple import DomainTuple
from .utilities import frozendict, indent
from .utilities import check_domain_equality, frozendict, indent
class MultiDomain:
......@@ -128,8 +128,7 @@ class MultiDomain:
for dom in inp:
for key, subdom in zip(dom._keys, dom._domains):
if key in res:
if res[key] != subdom:
raise ValueError("domain mismatch")
check_domain_equality(res[key], subdom)
else:
res[key] = subdom
return MultiDomain.make(res)
......
......@@ -41,8 +41,7 @@ class MultiField(Operator):
raise ValueError("length mismatch")
for d, v in zip(domain._domains, val):
if isinstance(v, Field):
if v._domain != d:
raise ValueError("domain mismatch")
utilities.check_domain_equality(v._domain, d)
else:
raise TypeError("bad entry in val (must be Field)")
self._domain = domain
......@@ -137,13 +136,9 @@ class MultiField(Operator):
for kk in domain.keys()}
return MultiField.from_dict(dct)
def _check_domain(self, other):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
def s_vdot(self, x):
result = 0.
self._check_domain(x)
utilities.check_domain_equality(x._domain, self._domain)
for v1, v2 in zip(self._val, x._val):
result += v1.s_vdot(v2)
return result
......@@ -369,8 +364,7 @@ class MultiField(Operator):
def _binary_op(self, other, op):
f = getattr(Field, op)
if isinstance(other, MultiField):
if self._domain != other._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(self._domain, other._domain)
val = tuple(f(v1, v2)
for v1, v2 in zip(self._val, other._val))
else:
......
......@@ -17,7 +17,7 @@
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..utilities import indent
from ..utilities import check_domain_equality, indent
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -75,16 +75,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
return MultiField(self._domain, val)
def _combine_chain(self, op):
if self._domain != op._domain:
raise ValueError("domain mismatch")
check_domain_equality(self._domain, op._domain)
res = {key: v1(v2)
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg):
from ..operators.sum_operator import SumOperator
if self._domain != op._domain:
raise ValueError("domain mismatch")
check_domain_equality(self._domain, op._domain)
res = {key: SumOperator.make([v1, v2], [selfneg, opneg])
for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
return BlockDiagonalOperator(self._domain, res)
......
......@@ -44,8 +44,7 @@ class ChainOperator(LinearOperator):
def simplify(ops):
# verify domains
for i in range(len(ops) - 1):
if ops[i + 1].target != ops[i].domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(ops[i + 1].target, ops[i].domain)
# unpack ChainOperators
opsnew = []
for op in ops:
......
......@@ -69,8 +69,7 @@ def _ConvolutionOperator(domain, kernel, space=None):
lm = [d for d in domain]
lm[space] = lm[space].get_default_codomain()
lm = DomainTuple.make(lm)
if lm[space] != kernel.domain[0]:
raise ValueError("Input domain and kernel are incompatible")
utilities.check_domain_equality(lm[space], kernel.domain[0])
HT = HarmonicTransformOperator(lm, domain[space], space)
diag = DiagonalOperator(kernel*domain[space].total_volume, lm, (space,))
wgt = WeightApplier(domain, space, 1)
......
......@@ -61,15 +61,13 @@ class DiagonalOperator(EndomorphicOperator):
self._domain = DomainTuple.make(domain)
if spaces is None:
self._spaces = None
if diagonal.domain != self._domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(diagonal.domain, self._domain)
else:
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
if len(self._spaces) != len(diagonal.domain):
raise ValueError("spaces and domain must have the same length")
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
utilities.check_domain_equality(diagonal.domain[i], self._domain[j])
if self._spaces == tuple(range(len(self._domain))):
self._spaces = None # shortcut
......
......@@ -347,8 +347,7 @@ class GaussianEnergy(LikelihoodEnergyOperator):
if self._domain is None:
self._domain = newdom
else:
if self._domain != newdom:
raise ValueError("domain mismatch")
utilities.check_domain_equality(self._domain, newdom)
def apply(self, x):
self._check_input(x)
......
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .operator import Operator
from ..utilities import check_domain_equality
class LinearOperator(Operator):
......@@ -252,4 +253,4 @@ class LinearOperator(Operator):
def _check_input(self, x, mode):
self._check_mode(mode)
self._check_domain_equality(self._dom(mode), x.domain)
check_domain_equality(self._dom(mode), x.domain)
......@@ -20,7 +20,7 @@ import numpy as np
from .. import pointwise
from ..logger import logger
from ..multi_domain import MultiDomain
from ..utilities import NiftyMeta, indent, myassert
from ..utilities import NiftyMeta, check_domain_equality, indent, myassert
class Operator(metaclass=NiftyMeta):
......@@ -131,17 +131,6 @@ class Operator(metaclass=NiftyMeta):
"""
return None
@staticmethod
def _check_domain_equality(dom_op, dom_field):
if dom_op != dom_field:
s = "The operator's and field's domains don't match."
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
if not isinstance(dom_op, (DomainTuple, MultiDomain,)):
s += " Your operator's domain is neither a `DomainTuple`" \
" nor a `MultiDomain`."
raise ValueError(s)
def scale(self, factor):
if factor == 1:
return self
......@@ -275,7 +264,7 @@ class Operator(metaclass=NiftyMeta):
raise ValueError
if x.jac._factor != 1:
raise ValueError
self._check_domain_equality(self._domain, x.domain)
check_domain_equality(self._domain, x.domain)
def __call__(self, x):
if not isinstance(x, Operator):
......@@ -417,8 +406,7 @@ class _OpChain(_CombinedOperator):
self._domain = self._ops[-1].domain
self._target = self._ops[0].target
for i in range(1, len(self._ops)):
if self._ops[i-1].domain != self._ops[i].target:
raise ValueError("domain mismatch")
check_domain_equality(self._ops[i-1].domain, self._ops[i].target)
def apply(self, x):
self._check_input(x)
......
......@@ -22,6 +22,7 @@ from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..utilities import check_domain_equality
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -363,8 +364,7 @@ class PartialExtractor(LinearOperator):
self._domain = domain
self._target = target
for key in self._target.keys():
if self._domain[key] is not self._target[key]:
raise ValueError("domain mismatch")
check_domain_equality(self._domain[key], self._target[key])
self._capability = self.TIMES | self.ADJOINT_TIMES
self._compldomain = MultiDomain.make({kk: self._domain[kk]
for kk in self._domain.keys()
......
......@@ -26,6 +26,7 @@ from .domains.power_space import PowerSpace
from .domains.rg_space import RGSpace
from .field import Field
from .minimization.iteration_controllers import EnergyHistory
from .utilities import check_domain_equality
# relevant properties:
# - x/y size
......@@ -313,8 +314,7 @@ def _plot1D(f, ax, **kwargs):
if (len(dom) != 1):
raise ValueError("input field must have exactly one domain")
else:
if fld.domain != dom:
raise ValueError("domain mismatch")
check_domain_equality(fld.domain, dom)
dom = dom[0]
label = kwargs.pop("label", None)
......
......@@ -420,8 +420,7 @@ def makeOp(input, dom=None):
raise TypeError("need proper `dom` argument")
return ScalingOperator(dom, input)
if dom is not None:
if not dom == input.domain:
raise ValueError("domain mismatch")
utilities.check_domain_equality(dom, input.domain)
if input.domain is DomainTuple.scalar_domain():
return ScalingOperator(input.domain, input.val[()])
if isinstance(input, Field):
......@@ -442,8 +441,8 @@ def domain_union(domains):
- if MultiDomain, there must not be any conflicting components
"""
if isinstance(domains[0], DomainTuple):
if any(dom != domains[0] for dom in domains[1:]):
raise ValueError("domain mismatch")
for dom in domains[1:]:
utilities.check_domain_equality(dom, domains[0])
return domains[0]
return MultiDomain.union(domains)
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -25,7 +25,7 @@ __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
"value_reshaper", "lognormal_moments", "check_domain_equality"]
def my_sum(iterable):
......@@ -412,3 +412,19 @@ def myassert(val):
`__debug__` is False."""
if not val:
raise AssertionError
def check_domain_equality(domain0, domain1):
"""Check if two domains are equal and throw ValueError if not. Throw a
TypeError if one of the inputs is neither a DomainTuple nor a
MultiDomain.
"""
from .domain_tuple import DomainTuple
from .multi_domain import MultiDomain
from .domains.domain import Domain
for dom in [domain0, domain1]:
if not isinstance(dom, (MultiDomain, DomainTuple, Domain)):
raise TypeError("The following domain is neither an instance of "
f"ift.MultiDomain nor of ift.DomainTuple.\n{dom}")
if domain0 != domain1:
raise ValueError(f"Domain mismatch:\n{domain0}\n{domain1}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment