Commit b2675597 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

try to speedup domain checks

parent 6e2cbb65
Pipeline #31767 failed with stages
in 2 minutes and 23 seconds
......@@ -75,8 +75,11 @@ class Model(NiftyMetaBase()):
raise NotImplementedError
def __str__(self):
s = '--------------------------------------------------------------------------------\n'
s += '<Nifty Model at {}>\n\n'.format(hex(id(self)))
s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format(self.position.domain, self.value.domain)
s += '--------------------------------------------------------------------------------\n'
s = ('----------------------------------------'
'----------------------------------------\n'
'<Nifty Model at {}>\n\n'.format(hex(id(self))))
s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format(
self.position.domain, self.value.domain)
s += ('---------------------------------------'
'-----------------------------------------\n')
return s
......@@ -46,6 +46,8 @@ class frozendict(collections.Mapping):
class MultiDomain(frozendict):
_domainCache = {}
_subsetCache = set()
_compatCache = set()
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
......@@ -83,22 +85,30 @@ class MultiDomain(frozendict):
def compatibleTo(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if (self, x) in MultiDomain._compatCache:
return True
commonKeys = set(self.keys()) & set(x.keys())
for key in commonKeys:
if self[key] != x[key]:
return False
MultiDomain._compatCache.add((self, x))
MultiDomain._compatCache.add((x, self))
return True
def subsetOf(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if (self, x) in MultiDomain._subsetCache:
return True
if len(x) == 0:
MultiDomain._subsetCache.add((self, x))
return True
for key in self.keys():
if key not in x:
return False
if self[key] != x[key]:
return False
MultiDomain._subsetCache.add((self, x))
return True
def unitedWith(self, x):
......
......@@ -179,22 +179,23 @@ class MultiField(object):
return True
if not isinstance(other, MultiField):
return False
for key, val in self._domain.items():
if key not in other._domain or other._domain[key] != val:
if len(set(self._domain.keys()) - set(other._domain.keys())) > 0:
return False
for key, val in self._val.items():
if not val.isSubsetOf(other[key]):
for key in self._domain.keys():
if other._domain[key] != self._domain[key]:
return False
if not other[key].isSubsetOf(self[key]):
return False
return True
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
"__div__", "__rdiv__", "__idiv__",
"__truediv__", "__rtruediv__", "__itruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__",
"__pow__", "__rpow__", "__ipow__",
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
......@@ -205,31 +206,26 @@ for op in ["__add__", "__radd__", "__iadd__",
else:
if not self._domain.compatibleTo(other.domain):
raise ValueError("domain mismatch")
fullkeys = set(self._domain.keys()) | set(other._domain.keys())
s1 = set(self._domain.keys())
s2 = set(other._domain.keys())
common_keys = s1 & s2
only_self_keys = s1 - s2
only_other_keys = s2 - s1
result_val = {}
if op in ["__iadd__", "__add__"]:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else None
f2 = other[key] if key in other._domain.keys() else None
if f1 is None:
result_val[key] = f2
elif f2 is None:
result_val[key] = f1
else:
result_val[key] = getattr(f1, op)(f2)
elif op in ["__mul__"]:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else None
f2 = other[key] if key in other._domain.keys() else None
if f1 is None or f2 is None:
continue
else:
result_val[key] = getattr(f1, op)(f2)
for key in common_keys:
result_val[key] = getattr(self[key], op)(other[key])
if op in ("__add__", "__radd__"):
for key in only_self_keys:
result_val[key] = self[key].copy()
for key in only_other_keys:
result_val[key] = other[key].copy()
elif op in ("__mul__", "__rmul__"):
pass
else:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else other[key]*0
f2 = other[key] if key in other._domain.keys() else self[key]*0
result_val[key] = getattr(f1, op)(f2)
for key in only_self_keys:
result_val[key] = getattr(self[key], op)(self[key]*0.)
for key in only_other_keys:
result_val[key] = getattr(other[key]*0., op)(other[key])
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment