Commit e227a907 authored by dboe's avatar dboe
Browse files

everything deepcopies now

parent 790752d9
......@@ -11,6 +11,10 @@ class Base_Check(object):
def demand_equal(self, other):
self.assertIsInstance(other, type(self._inst))
def demand_deep_copy(self, other):
self.demand_equal(other)
self.assertIsNot(self._inst, other)
def test_pickle(self):
with NamedTemporaryFile(suffix='.pickle') as out_file:
pickle.dump(self._inst,
......@@ -23,19 +27,16 @@ class Base_Check(object):
def test_deep_copy(self):
from copy import deepcopy
copy = deepcopy(self._inst)
self.demand_equal(copy)
self.assertIsNot(self._inst, copy)
other = deepcopy(self._inst)
self.demand_deep_copy(other)
def test_implicit_copy(self):
copy = type(self._inst)(self._inst)
self.demand_equal(copy)
self.assertIsNot(self._inst, copy)
other = type(self._inst)(self._inst)
self.demand_deep_copy(other)
def test_explicit_copy(self):
copy = self._inst.copy()
self.demand_equal(copy)
self.assertIsNot(self._inst, copy)
other = self._inst.copy()
self.demand_deep_copy(other)
def test_save_npz(self):
out_file = NamedTemporaryFile(suffix='.npz')
......@@ -190,6 +191,12 @@ class TensorFields_Check(Tensors_Check):
item.fields[i], np.array(self._inst.fields[i])[index]))
self.assertIsInstance(item.fields[i], check_type)
def demand_deep_copy(self, other):
super().demand_deep_copy(other)
self.assertIsNot(self._inst.fields, other.fields)
for i in range(len(self._inst.fields)):
self.assertIsNot(self._inst.fields[i], other.fields[i])
class TensorMaps_Check(TensorFields_Check):
def test_maps(self):
......@@ -206,6 +213,12 @@ class TensorMaps_Check(TensorFields_Check):
super().demand_index_equal(index, check_type)
# TODO: this is hard to check generically
def demand_deep_copy(self, other):
super().demand_deep_copy(other)
self.assertIsNot(self._inst.maps, other.maps)
for i in self._inst.maps:
self.assertIsNot(self._inst.maps[i], other.maps[i])
"""
EMPTY TESTS
......
......@@ -487,7 +487,8 @@ class AbstractNdarray(np.ndarray, AbstractObject):
# excluded from the __setstate__
# need to be in the same order as they
# have been added to __slots__
n_old = len(valid_slot_attrs) - len(state[5:])
n_np = 5 # number of numpy array states
n_old = len(valid_slot_attrs) - len(state[n_np:])
if n_old > 0:
for latest_index in range(n_old):
new_slot = added_slot_attrs[-latest_index]
......@@ -500,7 +501,7 @@ class AbstractNdarray(np.ndarray, AbstractObject):
setattr(self, new_slot, None)
for slot_index, slot in enumerate(valid_slot_attrs):
state_index = 5 + slot_index
state_index = n_np + slot_index
setattr(self, slot, state[state_index])
@property
......@@ -559,16 +560,21 @@ class AbstractNdarray(np.ndarray, AbstractObject):
>>> import tfields
>>> m = tfields.TensorMaps(
... [[1,2,3], [3,3,3], [0,0,0], [5,6,7]],
... [[1], [3], [0], [5]],
... maps=[
... ([[0, 1, 2], [1, 2, 3]], [21, 42]),
... [[1]],
... [[0, 1, 2, 3]]
... [[1]],
... [[0, 1, 2, 3]]
... ])
>>> mc = m.copy()
>>> mc.equal(m)
True
>>> mc is m
False
>>> mc.fields is m.fields
False
>>> mc.fields[0] is m.fields[0]
False
>>> mc.maps[3].fields[0] is m.maps[3].fields[0]
False
......@@ -1705,14 +1711,13 @@ def as_tensors_list(tensors_list):
tensors_list = new_list
return tensors_list
def as_maps(maps):
"""
Setter for TensorMaps.maps
Copies input
"""
# TODO: why not None?
if maps is not None and not isinstance(maps, Maps):
maps = Maps(maps)
maps = Maps(maps)
return maps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment