diff --git a/nifty5/field.py b/nifty5/field.py index d670a1316cba6e4c0dff29ea07f1017ff56223e8..9865d7137c04fe21266abae62edbd0ff4b992b30 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -630,7 +630,6 @@ for op in ["__add__", "__radd__", tval = getattr(self._val, op)(other) return Field(self._domain, tval) - raise TypeError("should not arrive here") return NotImplemented return func2 setattr(Field, op, func(op)) diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py index 3c44572640e5f0170ddae8fa995467878018f8ec..0ee207f811cc133535391b443234f30709ba8d18 100644 --- a/nifty5/multi/multi_field.py +++ b/nifty5/multi/multi_field.py @@ -229,6 +229,7 @@ class MultiField(object): res[key] = f[key] return MultiField.from_dict(res) + for op in ["__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__", diff --git a/nifty5/operator.py b/nifty5/operator.py index a6abdf5eeadd1834c7bd8f48fda5586d34a9bc88..52f01ca73bb5e23db2dc027097d30fbdd44c8ef3 100644 --- a/nifty5/operator.py +++ b/nifty5/operator.py @@ -6,8 +6,6 @@ import numpy as np from .compat import * from .utilities import NiftyMetaBase -#from ..domain_tuple import DomainTuple -#from ..multi.multi_domain import MultiDomain from .field import Field from .multi.multi_field import MultiField @@ -62,8 +60,8 @@ class Linearization(object): return Linearization(self._val*other._val, self._jac*d2 + d1*other._jac) if isinstance(other, (int, float, complex)): - #if other == 0: - # return ... + # if other == 0: + # return ... return Linearization(self._val*other, self._jac*other) if isinstance(other, (Field, MultiField)): d2 = makeOp(other) @@ -82,11 +80,13 @@ class Linearization(object): def make_var(field): from .operators.scaling_operator import ScalingOperator return Linearization(field, ScalingOperator(1., field.domain)) + @staticmethod def make_const(field): from .operators.null_operator import NullOperator return Linearization(field, NullOperator({}, field.domain)) + class Operator(NiftyMetaBase()): """Transforms values living on one domain into values living on another domain, and can also provide the Jacobian. diff --git a/nifty5/operators/field_adapter.py b/nifty5/operators/field_adapter.py index a1e7c57affba8c570c3b349a928a65abaf336c4e..77290c2af64718b6a005c5d303d7bdb3fecee121 100644 --- a/nifty5/operators/field_adapter.py +++ b/nifty5/operators/field_adapter.py @@ -5,6 +5,7 @@ from .linear_operator import LinearOperator from ..multi.multi_domain import MultiDomain from ..multi.multi_field import MultiField + class FieldAdapter(LinearOperator): def __init__(self, op, name_dom, name_tgt): if name_dom is None: diff --git a/nifty5/operators/relaxed_sum_operator.py b/nifty5/operators/relaxed_sum_operator.py index f7f50a6ec38a36a8b01e2871c6014ddd1c4ba21d..fb666a4ccfd499f8801760c107617c35ef060890 100644 --- a/nifty5/operators/relaxed_sum_operator.py +++ b/nifty5/operators/relaxed_sum_operator.py @@ -23,17 +23,19 @@ import numpy as np from ..compat import * from ..utilities import my_sum from .linear_operator import LinearOperator +from ..sugar import domain_union from ..multi.multi_domain import MultiDomain +from ..multi.multi_field import MultiField class RelaxedSumOperator(LinearOperator): - """Class representing sums of operators with compatible MultiDomains.""" + """Class representing sums of operators with compatible domains.""" def __init__(self, ops): super(RelaxedSumOperator, self).__init__() self._ops = ops - self._domain = MultiDomain.union([op.domain for op in ops]) - self._target = MultiDomain.union([op.target for op in ops]) + self._domain = domain_union([op.domain for op in ops]) + self._target = domain_union([op.target for op in ops]) self._capability = self.TIMES | self.ADJOINT_TIMES for op in ops: self._capability &= op.capability @@ -58,9 +60,14 @@ class RelaxedSumOperator(LinearOperator): self._check_mode(mode) res = None for op in self._ops: - tmp = x.extract(op._dom(mode), mode) + if isinstance(x.domain, MultiDomain): + x = x.extract(op._dom(mode)) + x = op.apply(x, mode) if res is None: res = tmp else: - res = MultiField.combine([res, tmp]) + if isinstance(x.domain, MultiDomain): + res = MultiField.combine([res, tmp]) + else: + res = res + tmp return res diff --git a/nifty5/sugar.py b/nifty5/sugar.py index cc2a2963c7cdad5d9e246852a6a6c2b0e54ce141..6e18db8991c6bae906370e796d5084fb01f6ee30 100644 --- a/nifty5/sugar.py +++ b/nifty5/sugar.py @@ -38,7 +38,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', 'full', 'from_global_data', 'from_local_data', 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate', - 'get_signal_variance', 'makeOp'] + 'get_signal_variance', 'makeOp', 'domain_union'] def PS_field(pspace, func): @@ -242,6 +242,14 @@ def makeOp(input): input.domain, tuple(makeOp(val) for val in input.values())) raise NotImplementedError + +def domain_union(domains): + if isinstance(domains[0], DomainTuple): + if any(dom is not domains[0] for dom in domains[1:]): + raise ValueError("domain mismatch") + return domains[0] + return MultiDomain.union(domains) + # Arithmetic functions working on Fields