Commit e92ca9ba authored by dboe's avatar dboe
Browse files

getitem works on tensor_grid now

parent 6ed6e427
Pipeline #106435 passed with stages
in 1 minute and 12 seconds
...@@ -57,6 +57,11 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check): ...@@ -57,6 +57,11 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check):
tge = tg.explicit() tge = tg.explicit()
self.check_filled(tge) self.check_filled(tge)
def test_getitem(self):
    """Slicing works on the implicit (empty) grid and its explicit form."""
    grid = TensorGrid.empty(*self.bv, iter_order=self.iter_order)
    # a full slice of either representation must yield a filled grid
    self.check_filled(grid[:])
    self.check_filled(grid.explicit()[:])
class TensorGrid_Test_Permutation1(TensorGrid_Test): class TensorGrid_Test_Permutation1(TensorGrid_Test):
def setUp(self): def setUp(self):
......
def pairwise(iterable):
    """Return an iterator over consecutive overlapping pairs.

    s -> (s0, s1), (s1, s2), (s2, s3), ...

    Returns:
        two iterators zipped together, one ahead of the other
    """
    behind, ahead = itertools.tee(iterable)
    next(ahead, None)  # advance the second iterator by one step
    return zip(behind, ahead)
...@@ -47,9 +48,13 @@ def flatten(seq, container=None, keep_types=None): ...@@ -47,9 +48,13 @@ def flatten(seq, container=None, keep_types=None):
keep_types = [] keep_types = []
if container is None: if container is None:
container = [] container = []
# pylint:disable=invalid-name
for s in seq: for s in seq:
if hasattr(s, '__iter__') and not isinstance(s, string_types) \ if (
and not any([isinstance(s, t) for t in keep_types]): hasattr(s, "__iter__")
and not isinstance(s, string_types)
and not any((isinstance(s, t) for t in keep_types))
):
flatten(s, container, keep_types) flatten(s, container, keep_types)
else: else:
container.append(s) container.append(s)
def multi_sort(array, *others, **kwargs):
    """Sort several sequences in lockstep, ordered by the first one.

    Args:
        array (iterable): sequence providing the sort key.
        *others (iterable): further sequences permuted alongside ``array``.
        **kwargs:
            method (callable): sorting function, defaults to ``sorted``.
            cast_type (type): container applied to each returned sequence
                (default ``list``).
            reverse (bool): if True, return the sequences in reversed order.
                NOTE(review): this replaces any user-supplied ``cast_type``
                with a reversed-list cast — kept as-is for compatibility.
            any remaining kwargs are forwarded to ``method``; a ``key``
            sorting on the first sequence is injected when absent.

    Returns:
        tuple: one container per input sequence, reordered in lockstep.
    """
    method = kwargs.pop("method", None)
    cast_type = kwargs.pop("cast_type", list)
    if len(array) == 0:
        # nothing to sort - still return one (empty) container per input
        return tuple(cast_type(seq) for seq in [array] + list(others))
    if method is None:
        method = sorted
    kwargs.setdefault("key", lambda pair: pair[0])
    if kwargs.pop("reverse", False):
        cast_type = lambda seq: list(reversed(seq))  # NOQA
    return tuple(cast_type(seq) for seq in zip(*method(zip(array, *others), **kwargs)))
def convert_nan(arr, value=0.0):
    """Replace all occurring NaN entries of *arr* in place.

    Args:
        arr (np.ndarray): float array; modified in place.
        value (float): replacement for every NaN entry (default 0.0).

    Returns:
        None: the input array is mutated in place, following the
        convention that in-place operations return nothing.
    """
    arr[np.isnan(arr)] = value
def view_1d(arr):
    """Collapse each row of a 2D array into a single void-typed scalar.

    The returned 1D view lets whole rows be compared at once, which
    speeds up duplicate-row searches.
    https://stackoverflow.com/a/44999009/ @Divakar

    Args:
        arr (np.ndarray): 2D array.

    Returns:
        np.ndarray: 1D array of ``np.void`` items, one per row of ``arr``;
        two items compare equal iff the corresponding rows are identical.
    """
    arr = np.ascontiguousarray(arr)
    # one void item spans the bytes of an entire row
    void_dt = np.dtype((np.void, arr.dtype.itemsize * arr.shape[1]))
    return arr.view(void_dt).ravel()
def argsort_unique(idx):
    """Invert a permutation array in O(n).

    For a permutation ``idx`` (e.g. the result of ``np.argsort``), return
    ``inverse`` such that ``idx[inverse]`` is ``np.arange(idx.size)``.
    https://stackoverflow.com/a/43411559/ @Divakar
    """
    inverse = np.empty(idx.size, dtype=int)
    inverse[idx] = np.arange(idx.size)
    return inverse
def duplicates(arr, axis=None):
    """Map every row to the index of its first identical occurrence.

    View1D version of duplicate search. Speed up version after
    https://stackoverflow.com/questions/46284660 \
        /python-numpy-speed-up-2d-duplicate-search/46294916#46294916

    Args:
        arr (array_like): 2D array, searched row-wise.
        axis (int): only ``axis=0`` (row-wise) is implemented.

    Returns:
        np.ndarray of int: for each row, the index of the first row equal
        to it (a row without duplicates maps to itself).
    """
    if len(arr) == 0:
        return np.array([])
    if axis != 0:
        raise NotImplementedError()
    # order the rows lexicographically so equal rows become adjacent
    order = np.lexsort(arr.T)
    ranked = arr[order]
    # positions where adjacent sorted rows differ -> start of a new group
    starts = np.flatnonzero((ranked[1:] != ranked[:-1]).any(1)) + 1
    bounds = np.concatenate(([0], starts, [ranked.shape[0]]))
    # group id per sorted position, then mapped back to the original order
    group_ids = np.repeat(range(len(bounds) - 1), np.diff(bounds))
    back_to_original = argsort_unique(order)
    first_of_group = order[bounds[:-1]]
    return first_of_group[group_ids[back_to_original]]
# pylint:disable=too-many-arguments
def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None):
    """Return the position of the first (close) occurrence of *entry*.

    Args:
        arr (np.ndarray): array searched through.
        entry: scalar (``axis=None``) or row (``axis=0``) to look for.
        rtol, atol, equal_nan: forwarded to :func:`np.isclose`.
        axis (int): ``None`` flattens ``arr`` first; otherwise only 0
            (row-wise comparison) is implemented.

    Returns:
        int or None: index of the first match, or None when *entry* does
        not occur in *arr*.
    """
    if axis is None:
        arr = arr.flatten()
    elif axis != 0:
        raise NotImplementedError()
    for position, part in enumerate(arr):
        close = np.isclose(part, entry, rtol=rtol, atol=atol, equal_nan=equal_nan)
        if axis is not None:
            # row comparison: every component must match
            close = close.all()
        if close:
            return position
    return None
if __name__ == "__main__":
    # run this module's docstring examples as a self-test
    import doctest

    doctest.testmod()
...@@ -36,6 +36,11 @@ class TensorGrid(TensorFields): ...@@ -36,6 +36,11 @@ class TensorGrid(TensorFields):
obj.iter_order = iter_order obj.iter_order = iter_order
return obj return obj
def __getitem__(self, index):
    """Index into the grid; an implicit (empty) grid is expanded first.

    Base-class indexing requires explicit tensors, so an empty grid is
    converted via ``explicit()`` before delegating the lookup.
    """
    if self.is_empty():
        return self.explicit().__getitem__(index)
    return super().__getitem__(index)
@classmethod @classmethod
def from_base_vectors(cls, *base_vectors, tensors=None, fields=None, **kwargs): def from_base_vectors(cls, *base_vectors, tensors=None, fields=None, **kwargs):
iter_order = kwargs.pop("iter_order", np.arange(len(base_vectors))) iter_order = kwargs.pop("iter_order", np.arange(len(base_vectors)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment