From 3d1fadb69cbd10ed689fcbe3f8bd8ac1b9a10dd4 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Fri, 22 Jun 2018 11:50:20 +0200 Subject: [PATCH] first try --- nifty5/domain_tuple.py | 17 ++++++++-- nifty5/multi/multi_domain.py | 41 +++++++++++++++++++++++ nifty5/multi/multi_field.py | 15 +++++++-- nifty5/operators/linear_operator.py | 2 +- nifty5/operators/sum_operator.py | 21 +++++++----- test/test_minimization/test_minimizers.py | 2 +- 6 files changed, 82 insertions(+), 16 deletions(-) diff --git a/nifty5/domain_tuple.py b/nifty5/domain_tuple.py index 39eafa885..b164988ab 100644 --- a/nifty5/domain_tuple.py +++ b/nifty5/domain_tuple.py @@ -128,13 +128,24 @@ class DomainTuple(object): def __eq__(self, x): if not isinstance(x, DomainTuple): x = DomainTuple.make(x) - if self is x: - return True - return self._dom == x._dom + return self is x def __ne__(self, x): return not self.__eq__(x) + def compatibleTo(self, x): + return self.__eq__(x) + + def subsetOf(self, x): + return self.__eq__(x) + + def unitedWith(self, x): + if not isinstance(x, DomainTuple): + x = DomainTuple.make(x) + if self != x: + raise ValueError("domain mismatch") + return self + def __str__(self): res = "DomainTuple, len: " + str(len(self)) for i in self: diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py index 89272688b..24b7524a3 100644 --- a/nifty5/multi/multi_domain.py +++ b/nifty5/multi/multi_domain.py @@ -71,3 +71,44 @@ class MultiDomain(frozendict): obj = MultiDomain(domain, _callingfrommake=True) MultiDomain._domainCache[domain] = obj return obj + + def __eq__(self, x): + if not isinstance(x, MultiDomain): + x = MultiDomain.make(x) + return self is x + + def __ne__(self, x): + return not self.__eq__(x) + + def compatibleTo(self, x): + if not isinstance(x, MultiDomain): + x = MultiDomain.make(x) + commonKeys = set(self.keys()) & set(x.keys()) + for key in commonKeys: + if self[key] != x[key]: + return False + return True + + def subsetOf(self, x): + if not isinstance(x, MultiDomain): + x = MultiDomain.make(x) + for key in self.keys(): + if not key in x: + return False + if self[key] != x[key]: + return False + return True + + def unitedWith(self, x): + if not isinstance(x, MultiDomain): + x = MultiDomain.make(x) + if self == x: + return self + if not self.compatibleTo(x): + raise ValueError("domain mismatch") + res = {} + for key, val in self.items(): + res[key] = val + for key, val in x.items(): + res[key] = val + return MultiDomain.make(res) diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py index 8e96a6e24..fe182774f 100644 --- a/nifty5/multi/multi_field.py +++ b/nifty5/multi/multi_field.py @@ -199,9 +199,18 @@ for op in ["__add__", "__radd__", "__iadd__", def func(op): def func2(self, other): if isinstance(other, MultiField): - self._check_domain(other) - result_val = {key: getattr(sub_field, op)(other[key]) - for key, sub_field in self.items()} + if self._domain == other._domain: + result_val = {key: getattr(sub_field, op)(other[key]) + for key, sub_field in self.items()} + else: + if not self._domain.compatibleTo(other.domain): + raise ValueError("domain mismatch") + fullkeys = set(self._domain.keys()) | set(other._domain.keys()) + result_val = {} + for key in fullkeys: + f1 = self[key] if key in self._domain.keys() else other[key]*0 + f2 = other[key] if key in other._domain.keys() else self[key]*0 + result_val[key] = getattr(f1, op)(f2) else: result_val = {key: getattr(val, op)(other) for key, val in self.items()} diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index 8cc61dd32..0f8d0aa33 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -280,5 +280,5 @@ class LinearOperator(NiftyMetaBase()): def _check_input(self, x, mode): self._check_mode(mode) - if x.domain != self._dom(mode): + if not self._dom(mode).subsetOf(x.domain): raise ValueError("The operator's and field's domains don't match.") diff --git a/nifty5/operators/sum_operator.py b/nifty5/operators/sum_operator.py index d32d7bfd7..b2bf53eb3 100644 --- a/nifty5/operators/sum_operator.py +++ b/nifty5/operators/sum_operator.py @@ -23,12 +23,14 @@ import numpy as np class SumOperator(LinearOperator): """Class representing sums of operators.""" - def __init__(self, ops, neg, _callingfrommake=False): + def __init__(self, ops, neg, dom, tgt, _callingfrommake=False): if not _callingfrommake: raise NotImplementedError super(SumOperator, self).__init__() self._ops = ops self._neg = neg + self._domain = dom + self._target = tgt self._capability = self.TIMES | self.ADJOINT_TIMES for op in ops: self._capability &= op.capability @@ -38,9 +40,12 @@ class SumOperator(LinearOperator): from .scaling_operator import ScalingOperator from .diagonal_operator import DiagonalOperator # Step 1: verify domains + dom = ops[0].domain + tgt = ops[0].target for op in ops[1:]: - if op.domain != ops[0].domain or op.target != ops[0].target: - raise ValueError("domain mismatch") + dom = dom.unitedWith(op.domain) + tgt = tgt.unitedWith(op.target) + # Step 2: unpack SumOperators opsnew = [] negnew = [] @@ -124,7 +129,7 @@ class SumOperator(LinearOperator): negnew.append(neg[i]) ops = opsnew neg = negnew - return ops, neg + return ops, neg, dom, tgt @staticmethod def make(ops, neg): @@ -134,18 +139,18 @@ class SumOperator(LinearOperator): raise ValueError("ops is empty") if len(ops) != len(neg): raise ValueError("length mismatch between ops and neg") - ops, neg = SumOperator.simplify(ops, neg) + ops, neg, dom, tgt = SumOperator.simplify(ops, neg) if len(ops) == 1 and not neg[0]: return ops[0] - return SumOperator(ops, neg, _callingfrommake=True) + return SumOperator(ops, neg, dom, tgt, _callingfrommake=True) @property def domain(self): - return self._ops[0].domain + return self._domain @property def target(self): - return self._ops[0].target + return self._target @property def adjoint(self): diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index 060fe30b3..c1c7632cf 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -76,7 +76,7 @@ class Test_Minimizers(unittest.TestCase): except ImportError: raise SkipTest np.random.seed(42) - space = ift.UnstructuredDomain((2,)) + space = ift.DomainTuple.make(ift.UnstructuredDomain((2,))) starting_point = ift.Field.from_random('normal', domain=space)*10 class RBEnergy(ift.Energy): -- GitLab