Commit 24e39747 authored by dboe's avatar dboe
Browse files

new Maps instantiation allows empty template tensors

parent dab8d7b9
......@@ -340,6 +340,23 @@ class Maps_Test(Base_Check, unittest.TestCase):
[[42]]])
class Maps_Init_Test(Maps_Test):
def setUp(self):
self._inst = tfields.Maps({3: [[0, 1, 2]], 0: [[]]})
class Maps_Rigid_Test(Maps_Test):
def setUp(self):
rig = tfields.Maps({0: [[1, 2, 42]], 3: [1]})
self.assertIsInstance(rig[0], tfields.TensorFields)
self.assertIsInstance(rig[3], tfields.TensorFields)
self.assertEqual(tfields.dim(rig[0]), 3)
self.assertEqual(tfields.dim(rig[3]), 1)
self.assertEqual(tfields.rank(rig[0]), 1)
self.assertEqual(tfields.rank(rig[3]), 0)
self._inst = rig
class Container_Check(AbstractNdarray_Check):
def demand_equal(self, other):
super().demand_equal(other)
......
......@@ -11,6 +11,7 @@ class Base_Check(object):
self.assertTrue(merged_without_templates.equal(merged))
self.assertTrue(len(templates), len(self._instances))
for template, inst in zip(templates, self._instances):
self.assertEqual(template.dim, 0) # merging templates save meory
cut = merged.cut(template)
self.demand_equal_cut(inst, template, cut)
......@@ -38,13 +39,10 @@ class TensorMaps_Empty_Test(Base_Check, unittest.TestCase):
self.assertEqual(len(cut.maps), len(template.maps))
for mp_dim, mp in inst.maps.items():
cut_map = cut.maps[mp_dim]
template_map = template.maps[mp_dim]
self.assertEqual(len(mp),
len(cut_map))
self.assertEqual(tfields.dim(mp),
tfields.dim(cut_map))
self.assertEqual(tfields.dim(template_map),
tfields.dim(cut_map))
self.assertTrue(inst.equal(cut)) # most important
......
......@@ -880,18 +880,19 @@ class Tensors(AbstractNdarray):
# first object with some entries
kwargs['dim'] = dim(objects[0])
inst = cls.__new__(cls, tensors, **kwargs)
if not return_templates:
return cls.__new__(cls, tensors, **kwargs)
return inst
else:
tensor_lengths = [len(o) for o in objects]
cum_tensor_lengths = [sum(tensor_lengths[:i])
for i in range(len(objects))]
templates = [
tfields.TensorFields(
obj,
np.empty((len(obj), 0)),
np.arange(tensor_lengths[i]) + cum_tensor_lengths[i])
for i, obj in enumerate(objects)]
return cls.__new__(cls, tensors, **kwargs), templates
return inst, templates
@classmethod
def grid(cls, *base_vectors, **kwargs):
......@@ -1447,10 +1448,11 @@ class Tensors(AbstractNdarray):
projected_field[nan_mask] = np.nan # correction for nan
fields.append(projected_field)
if dim(template) == 0:
tensors = np.array(self)[template.fields[0]]
# for speed circumvent __getitem__ of the complexer subclasses
tensors = Tensors(self)[template.fields[0]]
else:
tensors = template
return type(self)(Tensors(tensors), *fields)
return type(self)(tensors, *fields)
def cut(self, expression, coord_sys=None, return_template=False, **kwargs):
"""
......@@ -1943,7 +1945,6 @@ class TensorFields(Tensors):
else:
inst, templates = (return_value, None)
fields = []
if all([len(obj.fields) == len(objects[0].fields) for obj in objects]):
for fld_idx in range(len(objects[0].fields)):
......@@ -2104,55 +2105,53 @@ class Container(Fields):
class Maps(sortedcontainers.SortedDict, AbstractObject):
"""
A Maps object is a container for TensorFields sorted by dimension.
Indexing by dimension
Container for TensorFields sorted by dimension, i.e indexing by dimension
Args:
*args (
List(TensorFields):
| List(Tuple(int, TensorFields)):
| TensorFields:
| Tuple(Tensors, *Fields)):
TODO: there is more
)
"""
def __init__(self, *args, **kwargs):
if args and args[0] is None:
# None key passed e.g. by copy. We do not change keys here.
args = args[1:]
# convert all possible input arguments to only
if len(args) == 1 and isinstance(args[0],
sortedcontainers.SortedItemsView):
args = tuple([v for k, v in args[0]])
elif len(args) == 1 and isinstance(args[0], list):
new_args = tuple()
for entry in args[0]:
if len(args) == 1 and issubclass(type(args[0]), (list, dict)):
new_args = []
if issubclass(type(args[0]), list):
# Maps([...])
iterator = args[0]
elif issubclass(type(args[0]), dict):
# Maps({}), Maps(Maps(...)) - includes Maps i.e. copy
iterator = args[0].items()
for entry in iterator:
dimension = None
if issubclass(type(entry), tuple):
if np.issubdtype(type(entry[0]), np.integer):
# Maps([(key, value), ...])
new_args += (entry[1],)
# Maps([(key, value), ...]), Maps({key: value, ...})
mp = self.to_map(entry[1], copy=True)
dimension = entry[0]
else:
# Maps([(tensors, field1, field2), ...])
new_args += (entry,)
mp = self.to_map(*entry, copy=True)
else:
# Maps([mp, mp, ...])
new_args += (entry,)
args = new_args
elif len(args) == 1 and issubclass(type(args[0]), dict):
# Maps({}), Maps(Maps({})) - includes Maps i.e. copy
args = tuple(args[0].values())
elif len(args) == 0 and kwargs:
args = tuple(kwargs.values())
kwargs = {}
# By now everything must have been converted to flat args
# Maps(tfields.TensorFields([], dim=3), [[1,2,3]])
arg_tuple_list = []
for i, arg in enumerate(args):
if len(arg) == 2 and isinstance(arg[0], (int, np.integer)):
dimension, mp = arg
mp = self.to_map(mp, copy=True)
elif isinstance(arg, tuple):
mp = self.to_map(*arg, copy=True)
dimension = dim(mp)
else:
mp = self.to_map(arg, copy=True)
dimension = dim(mp)
arg_tuple_list.append((dimension, mp))
mp = self.to_map(entry, copy=True)
if dimension is None:
dimension = dim(mp)
new_args.append((dimension, mp))
super().__init__(arg_tuple_list, **kwargs)
args = (new_args,)
super().__init__(*args, **kwargs)
@staticmethod
def to_map(mp, *fields, copy=False, **kwargs):
......@@ -2170,29 +2169,16 @@ class Maps(sortedcontainers.SortedDict, AbstractObject):
else:
copy = True
if copy: # not else, because in case of wrong mp type we initialize
mp = TensorFields(mp, *fields, dtype=int, **kwargs)
if not mp.rank == 1:
raise ValueError(
"Incorrect map rank {mp.rank}".format(**locals())
)
kwargs.setdefault('dtype', int)
mp = TensorFields(mp, *fields, **kwargs)
return mp
def __setitem__(self, dimension, mp):
mp = self.to_map(mp)
if not dimension == mp.dim:
raise KeyError(
"Incorrect map dimension {mp.dim} for index {dim}"
.format(**locals())
)
super().__setitem__(dimension, mp)
def __getitem__(self, dimension):
if dimension == 0:
warnings.warn("Using map dimension 0")
return super().__getitem__(dimension)
def _args(self):
return super()._args() + tuple(self.values())
return super()._args() + ([(k, v) for k, v in self.items()],)
def equal(self, other, **kwargs):
"""
......@@ -2241,14 +2227,6 @@ class TensorMaps(TensorFields):
>>> mesh_cp_cyl = tfields.TensorMaps(mesh_copy)
>>> assert mesh_cp_cyl.coord_sys == tfields.bases.CYLINDER
Raises:
>>> import tfields
>>> tfields.TensorMaps([1] * 4,
... dim=3, maps=[[1, 2, 3]]) # +doctest: ELLIPSIS
Traceback (most recent call last):
...
ValueError: Incorrect map rank 0
"""
__slots__ = ['coord_sys', 'name', 'fields', 'maps']
__slot_setters__ = [tfields.bases.get_coord_system_name,
......@@ -2396,7 +2374,7 @@ class TensorMaps(TensorFields):
if return_templates:
mp, dimension_map_templates = return_value
for i in range(len(objects)):
template_maps_list[i].append(dimension_map_templates[i])
template_maps_list[i].append((dimension, dimension_map_templates[i]))
else:
mp = return_value
maps.append(mp)
......@@ -2406,7 +2384,9 @@ class TensorMaps(TensorFields):
for i, template_maps in enumerate(template_maps_list):
templates[i] = tfields.TensorMaps(
templates[i],
maps=template_maps)
maps=Maps(template_maps)) # template maps will not have
# dimensions according to their
# tensors which are indices
return inst, templates
else:
return inst
......@@ -2446,23 +2426,11 @@ class TensorMaps(TensorFields):
"""
inst = super()._cut_template(template) # this will set maps=Maps({})
# # Redirect maps and their fields
# maps = []
# for mp, template_mp in zip(self.maps.values(),
# template.maps.values()):
# mp_fields = []
# for field in mp.fields:
# if len(template_mp) == 0 and len(template_mp.fields) == 0:
# mp_fields.append(field[0:0]) # np.empty
# else:
# mp_fields.append(field[template_mp.fields[0].astype(int)])
# new_mp = tfields.TensorFields(tfields.Tensors(template_mp),
# *mp_fields)
# maps.append(new_mp)
# Redirect maps and their fields
if template.fields:
# bulk was cut so we need to correct the map references.
index_lut = np.full(len(self), np.nan)
index_lut = np.full(len(self), np.nan) # float type
index_lut[template.fields[0]] = np.arange(len(template.fields[0]))
for mp_dim, mp in self.maps.items():
mp = mp._cut_template(template.maps[mp_dim])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment