Commit 20a3adcb authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

rank 1 indices

parent 1e6ced88
......@@ -934,11 +934,12 @@ class Tensors(AbstractNdarray):
"""
return any(self.equal(other, return_bool=False).all(1))
def indices(self, tensor, rtol=None, atol=None, early_stopping=False):
def indices(self, tensor, rtol=None, atol=None):
"""
Returns:
list of int: indices of tensor occuring
Examples:
Rank 1 Tensors
>>> import tfields
>>> p = tfields.Tensors([[1,2,3], [4,5,6], [6,7,8], [4,5,6],
... [4.1, 5, 6]])
......@@ -947,6 +948,13 @@ class Tensors(AbstractNdarray):
>>> p.indices([4,5,6.1], rtol=1e-5, atol=1e-1)
array([1, 3, 4])
Rank 0 Tensors
>>> p = tfields.Tensors([2, 3, 6, 3.01])
>>> p.indices(3)
array([1])
>>> p.indices(3, rtol=1e-5, atol=1e-1)
array([1, 3])
"""
x, y = np.asarray(self), np.asarray(tensor)
if rtol is None and atol is None:
......@@ -955,16 +963,13 @@ class Tensors(AbstractNdarray):
equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol)
# inspired by https://stackoverflow.com/questions/19228295/find-ordered-vector-in-numpy-array
if self.rank == 0:
indices = np.where(equal_method((x-y), 0))[0]
elif self.rank == 1:
indices = np.where(np.all(equal_method((x-y), 0), axis=1))[0]
else:
raise NotImplementedError()
return indices
# old manual method
# indices = []
# for i, p in enumerate(x):
# if equal_method(p, y).all():
# indices.append(i)
# if early_stopping:
# break
# return indices
def index(self, tensor, **kwargs):
"""
......@@ -972,12 +977,8 @@ class Tensors(AbstractNdarray):
tensor
Returns:
int: index of tensor occuring
Raises:
ValueError: Multiple occurences
use early_stopping=True if first entry should be returned
"""
indices = self.indices(tensor, **kwargs)
print(indices)
if not indices:
return None
if len(indices) == 1:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment