Skip to content
Snippets Groups Projects
Commit e92ca9ba authored by dboe's avatar dboe
Browse files

getitem works on tensor_grid now:w

parent 6ed6e427
No related branches found
No related tags found
No related merge requests found
Pipeline #106435 passed
......@@ -57,6 +57,11 @@ class TensorGrid_Test(unittest.TestCase, TensorFields_Check):
tge = tg.explicit()
self.check_filled(tge)
def test_getitem(self):
tg = TensorGrid.empty(*self.bv, iter_order=self.iter_order)
self.check_filled(tg[:])
self.check_filled(tg.explicit()[:])
class TensorGrid_Test_Permutation1(TensorGrid_Test):
def setUp(self):
......
......@@ -14,6 +14,7 @@ def pairwise(iterable):
Returns:
two iterators, one ahead of the other
"""
# pylint:disable=invalid-name
a, b = itertools.tee(iterable)
next(b, None)
return zip(a, b)
......@@ -47,9 +48,13 @@ def flatten(seq, container=None, keep_types=None):
keep_types = []
if container is None:
container = []
# pylint:disable=invalid-name
for s in seq:
if hasattr(s, '__iter__') and not isinstance(s, string_types) \
and not any([isinstance(s, t) for t in keep_types]):
if (
hasattr(s, "__iter__")
and not isinstance(s, string_types)
and not any((isinstance(s, t) for t in keep_types))
):
flatten(s, container, keep_types)
else:
container.append(s)
......@@ -96,61 +101,63 @@ def multi_sort(array, *others, **kwargs):
((), (), ())
"""
method = kwargs.pop('method', None)
cast_type = kwargs.pop('cast_type', list)
method = kwargs.pop("method", None)
cast_type = kwargs.pop("cast_type", list)
if len(array) == 0:
return tuple(cast_type(x) for x in [array] + list(others))
if method is None:
method = sorted
if 'key' not in kwargs:
kwargs['key'] = lambda pair: pair[0]
if "key" not in kwargs:
kwargs["key"] = lambda pair: pair[0]
reverse = kwargs.pop('reverse', False)
reverse = kwargs.pop("reverse", False)
if reverse:
cast_type = lambda x: list(reversed(x)) # NOQA
return tuple(cast_type(x) for x in zip(*method(zip(array, *others), **kwargs)))
def convert_nan(ar, value=0.):
def convert_nan(arr, value=0.0):
"""
Replace all occuring NaN values by value
"""
nanIndices = np.isnan(ar)
ar[nanIndices] = value
nan_indices = np.isnan(arr)
arr[nan_indices] = value
def view1D(ar):
def view_1d(arr):
"""
Delete duplicate columns of the input array
https://stackoverflow.com/a/44999009/ @Divakar
"""
ar = np.ascontiguousarray(ar)
voidDt = np.dtype((np.void, ar.dtype.itemsize * ar.shape[1]))
return ar.view(voidDt).ravel()
arr = np.ascontiguousarray(arr)
coid_dt = np.dtype((np.void, arr.dtype.itemsize * arr.shape[1]))
return arr.view(coid_dt).ravel()
def argsort_unique(idx):
"""
https://stackoverflow.com/a/43411559/ @Divakar
"""
n = idx.size
sidx = np.empty(n, dtype=int)
sidx[idx] = np.arange(n)
num = idx.size
sidx = np.empty(num, dtype=int)
sidx[idx] = np.arange(num)
return sidx
def duplicates(ar, axis=None):
def duplicates(arr, axis=None):
"""
View1D version of duplicate search
Speed up version after
https://stackoverflow.com/questions/46284660 \
/python-numpy-speed-up-2d-duplicate-search/46294916#46294916
Args:
ar (array_like): array
arr (array_like): array
other args: see np.isclose
Examples:
>>> import tfields
>>> import numpy as np
......@@ -164,25 +171,26 @@ def duplicates(ar, axis=None):
Returns:
list of int: int is pointing to first occurence of unique value
"""
if len(ar) == 0:
if len(arr) == 0:
return np.array([])
if axis != 0:
raise NotImplementedError()
sidx = np.lexsort(ar.T)
b = ar[sidx]
sidx = np.lexsort(arr.T)
sorted_ = arr[sidx]
groupIndex0 = np.flatnonzero((b[1:] != b[:-1]).any(1)) + 1
groupIndex = np.concatenate(([0], groupIndex0, [b.shape[0]]))
ids = np.repeat(range(len(groupIndex) - 1), np.diff(groupIndex))
group_index_0 = np.flatnonzero((sorted_[1:] != sorted_[:-1]).any(1)) + 1
group_index = np.concatenate(([0], group_index_0, [sorted_.shape[0]]))
ids = np.repeat(range(len(group_index) - 1), np.diff(group_index))
sidx_mapped = argsort_unique(sidx)
ids_mapped = ids[sidx_mapped]
grp_minidx = sidx[groupIndex[:-1]]
grp_minidx = sidx[group_index[:-1]]
out = grp_minidx[ids_mapped]
return out
def index(ar, entry, rtol=0, atol=0, equal_nan=False, axis=None):
# pylint:disable=too-many-arguments
def index(arr, entry, rtol=0, atol=0, equal_nan=False, axis=None):
"""
Examples:
>>> import tfields
......@@ -198,18 +206,18 @@ def index(ar, entry, rtol=0, atol=0, equal_nan=False, axis=None):
list of int: indices of point occuring
"""
if axis is None:
ar = ar.flatten()
arr = arr.flatten()
elif axis != 0:
raise NotImplementedError()
for i, part in enumerate(ar):
isclose = np.isclose(part, entry, rtol=rtol, atol=atol,
equal_nan=equal_nan)
for i, part in enumerate(arr):
isclose = np.isclose(part, entry, rtol=rtol, atol=atol, equal_nan=equal_nan)
if axis is not None:
isclose = isclose.all()
if isclose:
return i
if __name__ == '__main__':
if __name__ == "__main__":
import doctest
doctest.testmod()
......@@ -36,6 +36,11 @@ class TensorGrid(TensorFields):
obj.iter_order = iter_order
return obj
def __getitem__(self, index):
if not self.is_empty():
return super().__getitem__(index)
return self.explicit().__getitem__(index)
@classmethod
def from_base_vectors(cls, *base_vectors, tensors=None, fields=None, **kwargs):
iter_order = kwargs.pop("iter_order", np.arange(len(base_vectors)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment