Commit f4ed25ce authored by dboe's avatar dboe
Browse files

deepcopy for copying

parent f89789a0
......@@ -21,10 +21,16 @@ class Base_Check(object):
self.demand_equal(reloaded)
def test_deep_copy(self):
from copy import deepcopy
copy = deepcopy(self._inst)
self.demand_equal(copy)
self.assertIsNot(self._inst, copy)
def test_implicit_copy(self):
copy = type(self._inst)(self._inst)
self.demand_equal(copy)
# self.assertIsNot(self._inst, copy)
self.assertIsNot(self._inst, copy)
def test_explicit_copy(self):
copy = self._inst.copy()
......
......@@ -18,6 +18,7 @@ import pathlib
from six import string_types
from contextlib import contextmanager
from collections import Counter
from copy import deepcopy
# 3rd party
import numpy as np
......@@ -571,25 +572,9 @@ class AbstractNdarray(np.ndarray, AbstractObject):
>>> mc.maps[3].fields[0] is m.maps[3].fields[0]
False
TODO:
This function implementation could be more general or maybe
redirect to deepcopy?
"""
inst = super().copy(*args, **kwargs)
for attr in self._iter_slots():
value = getattr(self, attr)
if hasattr(value, "copy") and not isinstance(value, list):
setattr(inst, attr, value.copy(*args, **kwargs))
elif isinstance(value, list):
list_copy = []
for item in value:
if hasattr(item, "copy"):
list_copy.append(item.copy(*args, **kwargs))
else:
list_copy.append(item)
setattr(inst, attr, list_copy)
return inst
# works with __reduce__ / __setstate__
return deepcopy(self)
class Tensors(AbstractNdarray):
......
......@@ -192,15 +192,18 @@ class Mesh3D(tfields.TensorMaps):
"""
def __new__(cls, tensors, *fields, **kwargs):
if not issubclass(type(tensors), Mesh3D):
kwargs['dim'] = 3
kwargs['dim'] = 3
if 'maps' in kwargs and 'faces' in kwargs:
raise ValueError("Conflicting options maps and faces")
faces = kwargs.pop('faces', None)
if not faces: # None or []
maps = kwargs.pop('maps', None)
if faces is not None and not faces:
# faces = []
faces = np.empty((0, 3))
maps = kwargs.pop('maps', [faces])
kwargs['maps'] = maps
if faces is not None:
maps = [faces]
if maps is not None:
kwargs['maps'] = maps
obj = super(Mesh3D, cls).__new__(cls, tensors, *fields, **kwargs)
if len(obj.maps) > 1:
raise ValueError("Mesh3D only allows one map")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment