Commit 1e6ced88 authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

new indices method

parent 071c85db
......@@ -938,19 +938,33 @@ class Tensors(AbstractNdarray):
"""
Returns:
list of int: indices of tensor occuring
Examples:
>>> import tfields
>>> p = tfields.Tensors([[1,2,3], [4,5,6], [6,7,8], [4,5,6],
... [4.1, 5, 6]])
>>> p.indices([4,5,6])
array([1, 3])
>>> p.indices([4,5,6.1], rtol=1e-5, atol=1e-1)
array([1, 3, 4])
"""
x, y = np.asarray(self), np.asarray(tensor)
if rtol is None and atol is None:
equal_method = np.array_equal
equal_method = np.equal
else:
equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol)
indices = []
for i, p in enumerate(x):
if equal_method(p, y).all():
indices.append(i)
if early_stopping:
break
# inspired by https://stackoverflow.com/questions/19228295/find-ordered-vector-in-numpy-array
indices = np.where(np.all(equal_method((x-y), 0), axis=1))[0]
return indices
# old manual method
# indices = []
# for i, p in enumerate(x):
# if equal_method(p, y).all():
# indices.append(i)
# if early_stopping:
# break
# return indices
def index(self, tensor, **kwargs):
"""
......
import numpy as np
import functools
import tfields.lib.util
def ensure_complex(*base_vectors):
......@@ -140,10 +141,10 @@ def base_vectors(array, rtol=None, atol=None):
"""
if len(array.shape) == 1:
values = sorted(set(array))
values = set(array)
if rtol is not None and atol is not None:
duplicates = set([])
for v1, v2 in tfields.lib.util.pairwise(values):
for v1, v2 in tfields.lib.util.pairwise(sorted(values)):
if np.isclose(v1, v2, rtol=rtol, atol=atol):
duplicates.add(v2)
values = values.difference(duplicates)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment