Commit d524becb authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

setter for fields implemented

parent 33e4f169
......@@ -163,12 +163,31 @@ class AbstractNdarray(np.ndarray):
Counterpart to __reduce__. Important for unpickling.
"""
# Call the parent's __setstate__ with the other tuple elements.
super(AbstractNdarray, self).__setstate__(state[0:-len(self._iter_slots())])
# numpy ndarray state has 5 entries
super(AbstractNdarray, self).__setstate__(state[:5])
# set the __slot__ attributes
for i, slot in enumerate(reversed(self._iter_slots())):
index = -(i + 1)
setattr(self, slot, state[index])
valid_slot_attrs = list(self._iter_slots())
added_slot_attrs = ['name'] # attributes that have been added later
# have not been pickled with the full
# information and thus need to be
# excluded from the __setstate__
# need to be in the same order as they have
# been added to __slots__
n_old = len(valid_slot_attrs) - len(state[5:])
if n_old > 0:
for latest_index in range(n_old):
new_slot = added_slot_attrs[-latest_index]
warnings.warn("Slots with names '{new_slot}' appears to have been"
"added after the creation of the reduced state. "
"No corresponding state found in __setstate__."
.format(**locals()))
valid_slot_attrs.pop(valid_slot_attrs.index(new_slot))
setattr(self, new_slot, None)
for slot_index, slot in enumerate(valid_slot_attrs):
state_index = 5 + slot_index
setattr(self, slot, state[state_index])
@property
def bulk(self):
......@@ -178,6 +197,25 @@ class AbstractNdarray(np.ndarray):
"""
return np.array(self)
@classmethod
@contextmanager
def _bypass_setter(cls, slot, demand_existence=False):
"""
Temporarily remove the setter in __slot_setters__ corresponding to slot
position in __slot__. You should know what you do, when using this.
"""
slot_index = cls.__slots__.index(slot) if slot in cls.__slots__ else None
if slot_index is None:
if demand_existence:
raise ValueError("Slot {slot} not existing".format(**locals()))
else:
yield
return
setter = cls.__slot_setters__[slot_index]
cls.__slot_setters__[slot_index] = None
yield
cls.__slot_setters__[slot_index] = setter
def copy(self, *args, **kwargs):
"""
The standard ndarray copy does not copy slots. Correct for this.
......@@ -379,18 +417,19 @@ class AbstractNdarray(np.ndarray):
bulk_type = getattr(tfields, bulk_type)
list_dict[key].append(bulk_type._from_dict(**sub_dict[index]))
'''
Build the normal way
'''
bulk = kwargs.pop('bulk')
bulk_type = kwargs.pop('bulk_type')
obj = cls.__new__(cls, bulk, **kwargs)
'''
Set list attributes
'''
for attr, list_value in list_dict.items():
setattr(obj, attr, list_value)
with cls._bypass_setter('fields'):
'''
Build the normal way
'''
bulk = kwargs.pop('bulk')
bulk_type = kwargs.pop('bulk_type')
obj = cls.__new__(cls, bulk, **kwargs)
'''
Set list attributes
'''
for attr, list_value in list_dict.items():
setattr(obj, attr, list_value)
return obj
......@@ -1287,6 +1326,33 @@ class Tensors(AbstractNdarray):
return artist
def as_tensors_list(tensors_list):
"""
Setter for TensorFields.fields
Copies input
Examples:
>>> import tfields
>>> import numpy as np
>>> scalars = tfields.Tensors([0, 1, 2])
>>> vectors = tfields.Tensors([[0, 0, 0], [0, 0, 1], [0, -1, 0]])
>>> maps = [tfields.TensorFields([[0, 1, 2], [0, 1, 2]]),
... tfields.TensorFields([[1], [2]], [-42, -21])]
>>> mesh = tfields.TensorMaps(vectors, scalars,
... maps=maps)
>>> mesh.maps[0].fields = [[42, 21]]
>>> assert len(mesh.maps[0].fields) == 1
>>> assert mesh.maps[0].fields[0].equal([42, 21])
"""
if tensors_list is not None:
new_list = []
for tensors in tensors_list:
tensors_list = Tensors(tensors)
new_list.append(tensors_list)
tensors_list = new_list
return tensors_list
class TensorFields(Tensors):
"""
Discrete Tensor Field
......@@ -1357,6 +1423,9 @@ class TensorFields(Tensors):
"""
__slots__ = ['coord_sys', 'name', 'fields']
__slot_setters__ = [tfields.bases.get_coord_system_name,
None,
as_tensors_list]
def __new__(cls, tensors, *fields, **kwargs):
rigid = kwargs.pop('rigid', True)
......@@ -1365,12 +1434,12 @@ class TensorFields(Tensors):
if issubclass(type(tensors), TensorFields):
if tensors.fields is None:
raise ValueError("Tensor fields were None")
obj.fields = [Tensors(field) for field in tensors.fields]
obj.fields = tensors.fields
elif not fields:
obj.fields = []
if fields:
# (over)write fields
obj.fields = [Tensors(field) for field in fields]
obj.fields = fields
if rigid:
olen = len(obj)
......@@ -1389,7 +1458,10 @@ class TensorFields(Tensors):
>>> import tfields
>>> import numpy as np
>>> vectors = tfields.Tensors([[0, 0, 0], [0, 0, 1], [0, -1, 0]])
>>> scalar_field = tfields.TensorFields(vectors, [42, 21, 10.5], [1, 2, 3])
>>> scalar_field = tfields.TensorFields(vectors,
... [42, 21, 10.5],
... [1, 2, 3],
... [[0, 0], [-1, -1], [-2, -2]])
Slicing
>>> sliced = scalar_field[2:]
......@@ -1400,6 +1472,7 @@ class TensorFields(Tensors):
Picking
>>> picked = scalar_field[1]
>>> assert np.array_equal(picked, [0, 0, 1])
>>> assert np.array_equal(picked.fields[0], 21)
Masking
>>> masked = scalar_field[[True, False, True]]
......@@ -1417,7 +1490,9 @@ class TensorFields(Tensors):
if isinstance(index, tuple):
index = index[0]
if item.fields:
item.fields = [field.__getitem__(index) for field in item.fields]
# circumvent the setter here.
with self._bypass_setter('fields'):
item.fields = [field.__getitem__(index) for field in item.fields]
except IndexError as err:
warnings.warn("Index error occured for field.__getitem__. Error "
"message: {err}".format(**locals()))
......@@ -1497,7 +1572,8 @@ class TensorFields(Tensors):
@names.setter
def names(self, names):
if not len(names) == len(self.fields):
raise ValueError("len(names) != len(fields)")
raise ValueError("len(names) ({0}) != len(fields) ({1})"
.format(len(names), len(self.fields)))
for i, name in enumerate(names):
self.fields[i].name = name
......@@ -2037,6 +2113,7 @@ class Container(AbstractNdarray):
if __name__ == '__main__': # pragma: no cover
import doctest
doctest.testmod()
# doctest.run_docstring_examples(as_tensors_list, globals())
# doctest.run_docstring_examples(Tensors._save_npz, globals())
# doctest.run_docstring_examples(TensorMaps.cut, globals())
# doctest.run_docstring_examples(AbstractNdarray._save_npz, globals())
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment