Commit aabb0370 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'subset_tweaks' into 'NIFTy_5'

Subset tweaks

See merge request ift/NIFTy!273
parents 6e2cbb65 cd9e78af
Pipeline #31788 failed with stages
in 4 minutes and 5 seconds
NIFTy - Numerical Information Field Theory
==========================================
[![build status](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_4/build.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_4)
[![coverage report](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_4/coverage.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_4)
[![build status](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_5/build.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_5)
[![coverage report](https://gitlab.mpcdf.mpg.de/ift/NIFTy/badges/NIFTy_5/coverage.svg)](https://gitlab.mpcdf.mpg.de/ift/NIFTy/commits/NIFTy_5)
**NIFTy** project homepage:
[http://ift.pages.mpcdf.de/NIFTy](http://ift.pages.mpcdf.de/NIFTy)
......@@ -62,7 +62,7 @@ distributions, the "apt" lines will need slight changes.
NIFTy5 and its mandatory dependencies can be installed via:
sudo apt-get install git libfftw3-dev python python-pip python-dev
pip install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_4
pip install --user git+https://gitlab.mpcdf.mpg.de/ift/NIFTy.git@NIFTy_5
(Note: If you encounter problems related to `pyFFTW`, make sure that you are
using a pip-installed `pyFFTW` package. Unfortunately, some distributions are
......
......@@ -68,8 +68,6 @@ class DomainTuple(object):
"""
if isinstance(domain, DomainTuple):
return domain
if isinstance(domain, dict):
return domain
domain = DomainTuple._parse_domain(domain)
obj = DomainTuple._tupleCache.get(domain)
if obj is not None:
......@@ -126,8 +124,9 @@ class DomainTuple(object):
return self._dom.__hash__()
def __eq__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self is x:
return True
x = DomainTuple.make(x)
return self is x
def __ne__(self, x):
......@@ -140,9 +139,10 @@ class DomainTuple(object):
return self.__eq__(x)
def unitedWith(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self != x:
if self is x:
return self
x = DomainTuple.make(x)
if self is not x:
raise ValueError("domain mismatch")
return self
......
......@@ -75,8 +75,11 @@ class Model(NiftyMetaBase()):
raise NotImplementedError
def __str__(self):
s = '--------------------------------------------------------------------------------\n'
s += '<Nifty Model at {}>\n\n'.format(hex(id(self)))
s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format(self.position.domain, self.value.domain)
s += '--------------------------------------------------------------------------------\n'
s = ('----------------------------------------'
'----------------------------------------\n'
'<Nifty Model at {}>\n\n'.format(hex(id(self))))
s += 'Position domain:\n{}\n\nValue domain:\n{}\n'.format(
self.position.domain, self.value.domain)
s += ('---------------------------------------'
'-----------------------------------------\n')
return s
......@@ -46,6 +46,8 @@ class frozendict(collections.Mapping):
class MultiDomain(frozendict):
_domainCache = {}
_subsetCache = set()
_compatCache = set()
def __init__(self, domain, _callingfrommake=False):
if not _callingfrommake:
......@@ -73,38 +75,53 @@ class MultiDomain(frozendict):
return obj
def __eq__(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if self is x:
return True
x = MultiDomain.make(x)
return self is x
def __ne__(self, x):
return not self.__eq__(x)
def __hash__(self):
return super(MultiDomain, self).__hash__()
def compatibleTo(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if self is x:
return True
x = MultiDomain.make(x)
if (self, x) in MultiDomain._compatCache:
return True
commonKeys = set(self.keys()) & set(x.keys())
for key in commonKeys:
if self[key] != x[key]:
if self[key] is not x[key]:
return False
MultiDomain._compatCache.add((self, x))
MultiDomain._compatCache.add((x, self))
return True
def subsetOf(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if self is x:
return True
x = MultiDomain.make(x)
if (self, x) in MultiDomain._subsetCache:
return True
if len(x) == 0:
MultiDomain._subsetCache.add((self, x))
return True
for key in self.keys():
if key not in x:
return False
if self[key] != x[key]:
if self[key] is not x[key]:
return False
MultiDomain._subsetCache.add((self, x))
return True
def unitedWith(self, x):
if not isinstance(x, MultiDomain):
x = MultiDomain.make(x)
if self == x:
if self is x:
return self
x = MultiDomain.make(x)
if self is x:
return self
if not self.compatibleTo(x):
raise ValueError("domain mismatch")
......
......@@ -179,22 +179,23 @@ class MultiField(object):
return True
if not isinstance(other, MultiField):
return False
for key, val in self._domain.items():
if key not in other._domain or other._domain[key] != val:
if len(set(self._domain.keys()) - set(other._domain.keys())) > 0:
return False
for key in self._domain.keys():
if other._domain[key] is not self._domain[key]:
return False
for key, val in self._val.items():
if not val.isSubsetOf(other[key]):
if not other[key].isSubsetOf(self[key]):
return False
return True
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
"__div__", "__rdiv__", "__idiv__",
"__truediv__", "__rtruediv__", "__itruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__",
"__pow__", "__rpow__", "__ipow__",
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
"__pow__", "__rpow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
......@@ -205,31 +206,28 @@ for op in ["__add__", "__radd__", "__iadd__",
else:
if not self._domain.compatibleTo(other.domain):
raise ValueError("domain mismatch")
fullkeys = set(self._domain.keys()) | set(other._domain.keys())
s1 = set(self._domain.keys())
s2 = set(other._domain.keys())
common_keys = s1 & s2
only_self_keys = s1 - s2
only_other_keys = s2 - s1
result_val = {}
if op in ["__iadd__", "__add__"]:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else None
f2 = other[key] if key in other._domain.keys() else None
if f1 is None:
result_val[key] = f2
elif f2 is None:
result_val[key] = f1
else:
result_val[key] = getattr(f1, op)(f2)
elif op in ["__mul__"]:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else None
f2 = other[key] if key in other._domain.keys() else None
if f1 is None or f2 is None:
continue
else:
result_val[key] = getattr(f1, op)(f2)
for key in common_keys:
result_val[key] = getattr(self[key], op)(other[key])
if op in ("__add__", "__radd__"):
for key in only_self_keys:
result_val[key] = self[key].copy()
for key in only_other_keys:
result_val[key] = other[key].copy()
elif op in ("__mul__", "__rmul__"):
pass
else:
for key in fullkeys:
f1 = self[key] if key in self._domain.keys() else other[key]*0
f2 = other[key] if key in other._domain.keys() else self[key]*0
result_val[key] = getattr(f1, op)(f2)
for key in only_self_keys:
result_val[key] = getattr(
self[key], op)(self[key]*0.)
for key in only_other_keys:
result_val[key] = getattr(
other[key]*0., op)(other[key])
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
......
......@@ -37,8 +37,6 @@ class Test_Functionality(unittest.TestCase):
assert_equal(f1.locked, False)
f1.lock()
assert_equal(f1.locked, True)
with assert_raises(ValueError):
f1 += f1
assert_equal(f1.locked_copy() is f1, True)
def test_fill(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment