diff --git a/nifty5/models/model.py b/nifty5/models/model.py index d3f2b5d65bde9ad00ce95ea0e19c5d3c56f9c827..d6de3b32214575ee886fd555b515079bf7d63318 100644 --- a/nifty5/models/model.py +++ b/nifty5/models/model.py @@ -75,8 +75,11 @@ class Model(NiftyMetaBase()): raise NotImplementedError def __str__(self): - s = '--------------------------------------------------------------------------------\n' - s += '<Nifty Model at {}>\n\n'.format(hex(id(self))) - s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format(self.position.domain, self.value.domain) - s += '--------------------------------------------------------------------------------\n' + s = ('----------------------------------------' + '----------------------------------------\n' + '<Nifty Model at {}>\n\n'.format(hex(id(self)))) + s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format( + self.position.domain, self.value.domain) + s += ('---------------------------------------' + '-----------------------------------------\n') return s diff --git a/nifty5/multi/multi_domain.py b/nifty5/multi/multi_domain.py index 5471a95afee6c7aa460b5a8bfd8d1e5583a3c03a..366a9b965f3183416c11142776803f2dee610574 100644 --- a/nifty5/multi/multi_domain.py +++ b/nifty5/multi/multi_domain.py @@ -46,6 +46,8 @@ class frozendict(collections.Mapping): class MultiDomain(frozendict): _domainCache = {} + _subsetCache = set() + _compatCache = set() def __init__(self, domain, _callingfrommake=False): if not _callingfrommake: @@ -83,22 +85,30 @@ class MultiDomain(frozendict): def compatibleTo(self, x): if not isinstance(x, MultiDomain): x = MultiDomain.make(x) + if (self, x) in MultiDomain._compatCache: + return True commonKeys = set(self.keys()) & set(x.keys()) for key in commonKeys: if self[key] != x[key]: return False + MultiDomain._compatCache.add((self, x)) + MultiDomain._compatCache.add((x, self)) return True def subsetOf(self, x): if not isinstance(x, MultiDomain): x = MultiDomain.make(x) + if (self, x) in MultiDomain._subsetCache: + return True if len(x) == 0: + MultiDomain._subsetCache.add((self, x)) return True for key in self.keys(): if key not in x: return False if self[key] != x[key]: return False + MultiDomain._subsetCache.add((self, x)) return True def unitedWith(self, x): diff --git a/nifty5/multi/multi_field.py b/nifty5/multi/multi_field.py index 7e852de634d777fae5f19133e49d04dad3fcc4a5..79c73a40d696f62728492a8f95cf54088cdde499 100644 --- a/nifty5/multi/multi_field.py +++ b/nifty5/multi/multi_field.py @@ -179,22 +179,23 @@ class MultiField(object): return True if not isinstance(other, MultiField): return False - for key, val in self._domain.items(): - if key not in other._domain or other._domain[key] != val: + if len(set(self._domain.keys()) - set(other._domain.keys())) > 0: + return False + for key in self._domain.keys(): + if other._domain[key] != self._domain[key]: return False - for key, val in self._val.items(): - if not val.isSubsetOf(other[key]): + if not other[key].isSubsetOf(self[key]): return False return True -for op in ["__add__", "__radd__", "__iadd__", - "__sub__", "__rsub__", "__isub__", - "__mul__", "__rmul__", "__imul__", - "__div__", "__rdiv__", "__idiv__", - "__truediv__", "__rtruediv__", "__itruediv__", - "__floordiv__", "__rfloordiv__", "__ifloordiv__", - "__pow__", "__rpow__", "__ipow__", +for op in ["__add__", "__radd__", + "__sub__", "__rsub__", + "__mul__", "__rmul__", + "__div__", "__rdiv__", + "__truediv__", "__rtruediv__", + "__floordiv__", "__rfloordiv__", + "__pow__", "__rpow__", "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: def func(op): def func2(self, other): @@ -205,31 +206,26 @@ for op in ["__add__", "__radd__", "__iadd__", else: if not self._domain.compatibleTo(other.domain): raise ValueError("domain mismatch") - fullkeys = set(self._domain.keys()) | set(other._domain.keys()) + s1 = set(self._domain.keys()) + s2 = set(other._domain.keys()) + common_keys = s1 & s2 + only_self_keys = s1 - s2 + only_other_keys = s2 - s1 result_val = {} - if op in ["__iadd__", "__add__"]: - for key in fullkeys: - f1 = self[key] if key in self._domain.keys() else None - f2 = other[key] if key in other._domain.keys() else None - if f1 is None: - result_val[key] = f2 - elif f2 is None: - result_val[key] = f1 - else: - result_val[key] = getattr(f1, op)(f2) - elif op in ["__mul__"]: - for key in fullkeys: - f1 = self[key] if key in self._domain.keys() else None - f2 = other[key] if key in other._domain.keys() else None - if f1 is None or f2 is None: - continue - else: - result_val[key] = getattr(f1, op)(f2) + for key in common_keys: + result_val[key] = getattr(self[key], op)(other[key]) + if op in ("__add__", "__radd__"): + for key in only_self_keys: + result_val[key] = self[key].copy() + for key in only_other_keys: + result_val[key] = other[key].copy() + elif op in ("__mul__", "__rmul__"): + pass else: - for key in fullkeys: - f1 = self[key] if key in self._domain.keys() else other[key]*0 - f2 = other[key] if key in other._domain.keys() else self[key]*0 - result_val[key] = getattr(f1, op)(f2) + for key in only_self_keys: + result_val[key] = getattr(self[key], op)(self[key]*0.) + for key in only_other_keys: + result_val[key] = getattr(other[key]*0., op)(other[key]) else: result_val = {key: getattr(val, op)(other) for key, val in self.items()}