Commit b2675597 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

try to speedup domain checks

parent 6e2cbb65
Pipeline #31767 failed with stages
in 2 minutes and 23 seconds
...@@ -75,8 +75,11 @@ class Model(NiftyMetaBase()): ...@@ -75,8 +75,11 @@ class Model(NiftyMetaBase()):
raise NotImplementedError raise NotImplementedError
def __str__(self): def __str__(self):
s = '--------------------------------------------------------------------------------\n' s = ('----------------------------------------'
s += '<Nifty Model at {}>\n\n'.format(hex(id(self))) '----------------------------------------\n'
s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format(self.position.domain, self.value.domain) '<Nifty Model at {}>\n\n'.format(hex(id(self))))
s += '--------------------------------------------------------------------------------\n' s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format(
self.position.domain, self.value.domain)
s += ('---------------------------------------'
'-----------------------------------------\n')
return s return s
...@@ -46,6 +46,8 @@ class frozendict(collections.Mapping): ...@@ -46,6 +46,8 @@ class frozendict(collections.Mapping):
class MultiDomain(frozendict): class MultiDomain(frozendict):
_domainCache = {} _domainCache = {}
_subsetCache = set()
_compatCache = set()
def __init__(self, domain, _callingfrommake=False): def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake: if not _callingfrommake:
...@@ -83,22 +85,30 @@ class MultiDomain(frozendict): ...@@ -83,22 +85,30 @@ class MultiDomain(frozendict):
def compatibleTo(self, x): def compatibleTo(self, x):
if not isinstance(x, MultiDomain): if not isinstance(x, MultiDomain):
x = MultiDomain.make(x) x = MultiDomain.make(x)
if (self, x) in MultiDomain._compatCache:
return True
commonKeys = set(self.keys()) & set(x.keys()) commonKeys = set(self.keys()) & set(x.keys())
for key in commonKeys: for key in commonKeys:
if self[key] != x[key]: if self[key] != x[key]:
return False return False
MultiDomain._compatCache.add((self, x))
MultiDomain._compatCache.add((x, self))
return True return True
def subsetOf(self, x): def subsetOf(self, x):
if not isinstance(x, MultiDomain): if not isinstance(x, MultiDomain):
x = MultiDomain.make(x) x = MultiDomain.make(x)
if (self, x) in MultiDomain._subsetCache:
return True
if len(x) == 0: if len(x) == 0:
MultiDomain._subsetCache.add((self, x))
return True return True
for key in self.keys(): for key in self.keys():
if key not in x: if key not in x:
return False return False
if self[key] != x[key]: if self[key] != x[key]:
return False return False
MultiDomain._subsetCache.add((self, x))
return True return True
def unitedWith(self, x): def unitedWith(self, x):
......
...@@ -179,22 +179,23 @@ class MultiField(object): ...@@ -179,22 +179,23 @@ class MultiField(object):
return True return True
if not isinstance(other, MultiField): if not isinstance(other, MultiField):
return False return False
for key, val in self._domain.items(): if len(set(self._domain.keys()) - set(other._domain.keys())) > 0:
if key not in other._domain or other._domain[key] != val: return False
for key in self._domain.keys():
if other._domain[key] != self._domain[key]:
return False return False
for key, val in self._val.items(): if not other[key].isSubsetOf(self[key]):
if not val.isSubsetOf(other[key]):
return False return False
return True return True
for op in ["__add__", "__radd__", "__iadd__", for op in ["__add__", "__radd__",
"__sub__", "__rsub__", "__isub__", "__sub__", "__rsub__",
"__mul__", "__rmul__", "__imul__", "__mul__", "__rmul__",
"__div__", "__rdiv__", "__idiv__", "__div__", "__rdiv__",
"__truediv__", "__rtruediv__", "__itruediv__", "__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__", "__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__", "__ipow__", "__pow__", "__rpow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]: "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op): def func(op):
def func2(self, other): def func2(self, other):
...@@ -205,31 +206,26 @@ for op in ["__add__", "__radd__", "__iadd__", ...@@ -205,31 +206,26 @@ for op in ["__add__", "__radd__", "__iadd__",
else: else:
if not self._domain.compatibleTo(other.domain): if not self._domain.compatibleTo(other.domain):
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
fullkeys = set(self._domain.keys()) | set(other._domain.keys()) s1 = set(self._domain.keys())
s2 = set(other._domain.keys())
common_keys = s1 & s2
only_self_keys = s1 - s2
only_other_keys = s2 - s1
result_val = {} result_val = {}
if op in ["__iadd__", "__add__"]: for key in common_keys:
for key in fullkeys: result_val[key] = getattr(self[key], op)(other[key])
f1 = self[key] if key in self._domain.keys() else None if op in ("__add__", "__radd__"):
f2 = other[key] if key in other._domain.keys() else None for key in only_self_keys:
if f1 is None: result_val[key] = self[key].copy()
result_val[key] = f2 for key in only_other_keys:
elif f2 is None: result_val[key] = other[key].copy()
result_val[key] = f1 elif op in ("__mul__", "__rmul__"):
else: pass
result_val[key] = getattr(f1, op)(f2)
elif op in ["__mul__"]:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else None
f2 = other[key] if key in other._domain.keys() else None
if f1 is None or f2 is None:
continue
else:
result_val[key] = getattr(f1, op)(f2)
else: else:
for key in fullkeys: for key in only_self_keys:
f1 = self[key] if key in self._domain.keys() else other[key]*0 result_val[key] = getattr(self[key], op)(self[key]*0.)
f2 = other[key] if key in other._domain.keys() else self[key]*0 for key in only_other_keys:
result_val[key] = getattr(f1, op)(f2) result_val[key] = getattr(other[key]*0., op)(other[key])
else: else:
result_val = {key: getattr(val, op)(other) result_val = {key: getattr(val, op)(other)
for key, val in self.items()} for key, val in self.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment