Commit 3d85e742 authored by Daniel Böckenhoff's avatar Daniel Böckenhoff
Browse files

copy, rewritten getitem and setitem, copyconstructors

parent ed4777c7
......@@ -162,6 +162,38 @@ class AbstractNdarray(np.ndarray):
index = -(i + 1)
setattr(self, slot, state[index])
def copy(self, *args, **kwargs):
"""
The standard ndarray copy does not copy slots. Correct for this.
Examples:
>>> import tfields
>>> m = tfields.TensorMaps([[1,2,3], [3,3,3], [0,0,0], [5,6,7]],
... maps=[tfields.TensorFields([[0, 1, 2], [1, 2, 3]],
... [1, 2])])
>>> mc = m.copy()
>>> mc is m
False
>>> mc.maps[0].fields[0] is m.maps[0].fields[0]
False
TODO: This function implementation could be more general or maybe redirect to deepcopy?
"""
inst = super(AbstractNdarray, self).copy(*args, **kwargs)
for attr in self._iter_slots():
value = getattr(self, attr)
if hasattr(value, 'copy'):
setattr(inst, attr, value.copy(*args, **kwargs))
elif isinstance(value, list):
list_copy = []
for item in value:
if hasattr(item, 'copy'):
list_copy.append(item.copy(*args, **kwargs))
else:
list_copy.append(item)
setattr(inst, attr, list_copy)
return inst
def save(self, path, *args, **kwargs):
"""
Saving a tensors object by redirecting to the correct save method depending on path
......@@ -332,7 +364,7 @@ class Tensors(AbstractNdarray):
if issubclass(type(tensors), cls):
obj = tensors.copy()
if dtype != obj.dtype or order is not None:
obj.astype(dtype, order=order)
obj = obj.astype(dtype, order=order)
coordSys = kwargs.pop('coordSys', None)
if kwargs:
raise AttributeError("In copy constructor only 'dtype' and 'coordSys' "
......@@ -359,7 +391,10 @@ class Tensors(AbstractNdarray):
''' check dimension(s) '''
for d in obj.shape[1:]:
if not d == obj.dim:
raise ValueError("Dimensions are inconstistent.")
raise ValueError("Dimensions are inconstistent. "
"Manifold dimension is {obj.dim}, "
"Found dimensions {found} in {obj}."
.format(found=obj.shape[1:], **locals()))
if dim is not None:
if dim != obj.dim:
raise ValueError("Incorrect dimension: {obj.dim} given,"
......@@ -945,16 +980,38 @@ class TensorFields(Tensors):
>>> multiField.fields[1].dim
3
Empty initialization
>>> empty_field = TensorFields([], dim=3)
>>> assert empty_field.shape == (0, 3)
>>> assert empty_field.fields == []
Directly initializing with lists or arrays
>>> vec_field_raw = tfields.TensorFields([[0, 1, 2], [3, 4, 5]],
... [1, 6], [2, 7])
>>> assert len(vec_field_raw.fields) == 2
Copying
>>> cp = TensorFields(vectorField)
>>> assert vectorField.equal(cp)
Copying with changing type
>>> tcp = TensorFields(vectorField, dtype=int)
>>> assert vectorField.equal(tcp)
>>> assert tcp.dtype == int
"""
__slots__ = ['coordSys', 'fields']
def __new__(cls, tensors, *fields, **kwargs):
obj = super(TensorFields, cls).__new__(cls, tensors, **kwargs)
if fields or (issubclass(type(tensors), cls) and fields):
if issubclass(type(tensors), TensorFields):
if tensors.fields is None:
raise ValueError("Tensor fields were None")
obj.fields = [Tensors(field) for field in tensors.fields]
elif not fields:
obj.fields = []
if fields:
# (over)write fields
obj.fields = [Tensors(field) for field in fields]
return obj
......@@ -985,13 +1042,14 @@ class TensorFields(Tensors):
"""
item = super(TensorFields, self).__getitem__(index)
if isinstance(item, TensorFields):
if issubclass(type(item), TensorFields):
if isinstance(index, slice):
item.fields = [field.__getitem__(index) for field in item.fields]
elif isinstance(index, tuple):
item.fields = [field.__getitem__(index[0]) for field in item.fields]
else:
item.fields = [field.__getitem__(index) for field in item.fields]
if item.fields:
item.fields = [field.__getitem__(index) for field in item.fields]
return item
......@@ -1016,7 +1074,7 @@ class TensorFields(Tensors):
"""
super(TensorFields, self).__setitem__(index, item)
if isinstance(item, TensorFields):
if issubclass(type(item), TensorFields):
if isinstance(index, slice):
for i, field in enumerate(item.fields):
self.fields[i].__setitem__(index, field)
......@@ -1036,6 +1094,7 @@ class TensorFields(Tensors):
"""
Test, whether the instance has the same content as other.
Args:
other (iterable)
optional:
see Tensors.equal
"""
......@@ -1045,8 +1104,11 @@ class TensorFields(Tensors):
with other.tmp_transform(self.coordSys):
mask = super(TensorFields, self).equal(other, **kwargs)
if issubclass(type(other), TensorFields):
for i, field in enumerate(self.fields):
mask &= field.equal(other.fields[i], **kwargs)
if len(self.fields) != len(other.fields):
mask &= False
else:
for i, field in enumerate(self.fields):
mask &= field.equal(other.fields[i], **kwargs)
return mask
......@@ -1060,7 +1122,7 @@ class TensorMaps(TensorFields):
maps (array-like): indices indicating a connection between the
tensors at the respective index positions
Examples:
>>> from tfields import Tensors, TensorMaps
>>> from tfields import Tensors, TensorFields, TensorMaps
>>> scalars = Tensors([0, 1, 2])
>>> vectors = Tensors([[0, 0, 0], [0, 0, 1], [0, -1, 0]])
>>> maps = [TensorFields([[0, 1, 2], [0, 1, 2]], [42, 21]),
......@@ -1070,8 +1132,8 @@ class TensorMaps(TensorFields):
>>> assert isinstance(mesh.maps, list)
>>> assert len(mesh.maps) == 2
>>> print mesh.fields
>>> print mesh.maps[0].fields
>>> assert mesh.equal(TensorFields(vectors, scalars))
>>> assert mesh.maps[0].fields[0].equal(maps[0].fields[0])
"""
__slots__ = ['coordSys', 'fields', 'maps']
......@@ -1080,23 +1142,52 @@ class TensorMaps(TensorFields):
maps = kwargs.pop('maps', [])
maps_cp = []
for mp in maps:
try:
mp = TensorFields(mp, dtype=int)
except Exception as err:
raise ValueError("Could not cast map {mp} to TensorFields instance."
" Error '{err}' occured."
.format(**locals()))
mp = TensorFields(mp, dtype=int)
maps_cp.append(mp)
kwargs['maps'] = maps_cp
obj = super(TensorMaps, cls).__new__(cls, tensors, *fields, **kwargs)
return obj
def equal(self, other, **kwargs):
"""
Test, whether the instance has the same content as other.
Args:
other (iterable)
optional:
see TensorFields.equal
Examples:
>>> import tfields
>>> maps = [tfields.TensorFields([[1]], [42])]
>>> tm = tfields.TensorMaps(maps[0], maps=maps)
# >>> assert tm.equal(tm)
>>> cp = tm.copy()
# >>> assert tm.equal(cp)
>>> cp.maps[0].fields[0] = -42
>>> assert tm.maps[0].fields[0] == 42
>>> assert not tm.equal(cp)
"""
if not issubclass(type(other), Tensors):
return super(TensorMaps, self).equal(other, **kwargs)
else:
with other.tmp_transform(self.coordSys):
mask = super(TensorMaps, self).equal(other, **kwargs)
if issubclass(type(other), TensorMaps):
if len(self.maps) != len(other.maps):
mask &= False
else:
for i, mp in enumerate(self.maps):
mask &= mp.equal(other.maps[i], **kwargs)
return mask
def stale(self):
"""
Returns:
Mask for all vertices that are stale i.e. are not refered by maps
Examples:
>>> from tfields import Tensors, TensorMaps
>>> from tfields import Tensors, TensorFields, TensorMaps
>>> vectors = Tensors([[0, 0, 0], [0, 0, 1], [0, -1, 0], [4, 4, 4]])
>>> tm = TensorMaps(vectors, maps=[[[0, 1, 2], [0, 1, 2]],
... [[1, 1], [2, 2]]])
......@@ -1118,8 +1209,9 @@ class TensorMaps(TensorFields):
Examples:
>>> import tfields
>>> mp1 = tfields.TensorFields([[0, 1, 2], [3, 4, 5]],
... [[1,2,3,4,5], [6,7,8,9,0]])
... *zip([1,2,3,4,5], [6,7,8,9,0]))
>>> mp2 = tfields.TensorFields([[0], [3]])
>>> tm = tfields.TensorMaps([[0,0,0], [1,1,1], [2,2,2], [0,0,0],
... [3,3,3], [4,4,4], [5,6,7]],
... maps=[mp1, mp2])
......@@ -1172,8 +1264,7 @@ class TensorMaps(TensorFields):
... maps=[TensorFields([[0, 1, 2], [0, 1, 3],
... [3, 4, 5], [3, 4, 1],
... [3, 4, 6]],
... [[1,2], [3,4], [5,6], [7,8], [9,0]])])
>>> m.maps[0].fields
... [1, 3, 5, 7, 9], [2, 4, 6, 8, 0])])
>>> c = m.removed([True, True, True, False, False, False, False])
>>> c
TensorMaps([[ 0., 0., 0.],
......@@ -1181,11 +1272,11 @@ class TensorMaps(TensorFields):
[ 4., 4., 4.],
[ 5., 5., 5.]])
>>> c.maps[0]
array([[0, 1, 2],
[0, 1, 3]])
TensorFields([[0, 1, 2],
[0, 1, 3]])
>>> c.maps[0].fields
array([[ 5., 6.],
[ 9., 0.]])
[TensorFields([[ 5., 6.],
[ 9., 0.]])]
"""
remove_condition = np.array(remove_condition)
......@@ -1238,7 +1329,8 @@ class TensorMaps(TensorFields):
if __name__ == '__main__': # pragma: no cover
import doctest
# doctest.testmod()
# doctest.run_docstring_examples(TensorMaps.cleaned, globals())
doctest.run_docstring_examples(TensorMaps, globals())
doctest.testmod()
doctest.run_docstring_examples(TensorMaps.cleaned, globals())
doctest.run_docstring_examples(TensorFields, globals())
doctest.run_docstring_examples(AbstractNdarray.copy, globals())
doctest.run_docstring_examples(TensorMaps.equal, globals())
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment