Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Daniel Boeckenhoff
tfields
Commits
20a3adcb
Commit
20a3adcb
authored
Nov 23, 2018
by
Daniel Boeckenhoff
Browse files
rank 1 indices
parent
1e6ced88
Changes
1
Hide whitespace changes
Inline
Side-by-side
tfields/core.py
View file @
20a3adcb
...
...
@@ -934,11 +934,12 @@ class Tensors(AbstractNdarray):
"""
return
any
(
self
.
equal
(
other
,
return_bool
=
False
).
all
(
1
))
def
indices
(
self
,
tensor
,
rtol
=
None
,
atol
=
None
,
early_stopping
=
False
):
def
indices
(
self
,
tensor
,
rtol
=
None
,
atol
=
None
):
"""
Returns:
list of int: indices of tensor occuring
Examples:
Rank 1 Tensors
>>> import tfields
>>> p = tfields.Tensors([[1,2,3], [4,5,6], [6,7,8], [4,5,6],
... [4.1, 5, 6]])
...
...
@@ -947,6 +948,13 @@ class Tensors(AbstractNdarray):
>>> p.indices([4,5,6.1], rtol=1e-5, atol=1e-1)
array([1, 3, 4])
Rank 0 Tensors
>>> p = tfields.Tensors([2, 3, 6, 3.01])
>>> p.indices(3)
array([1])
>>> p.indices(3, rtol=1e-5, atol=1e-1)
array([1, 3])
"""
x
,
y
=
np
.
asarray
(
self
),
np
.
asarray
(
tensor
)
if
rtol
is
None
and
atol
is
None
:
...
...
@@ -955,16 +963,13 @@ class Tensors(AbstractNdarray):
equal_method
=
lambda
a
,
b
:
np
.
isclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
# inspired by https://stackoverflow.com/questions/19228295/find-ordered-vector-in-numpy-array
indices
=
np
.
where
(
np
.
all
(
equal_method
((
x
-
y
),
0
),
axis
=
1
))[
0
]
if
self
.
rank
==
0
:
indices
=
np
.
where
(
equal_method
((
x
-
y
),
0
))[
0
]
elif
self
.
rank
==
1
:
indices
=
np
.
where
(
np
.
all
(
equal_method
((
x
-
y
),
0
),
axis
=
1
))[
0
]
else
:
raise
NotImplementedError
()
return
indices
# old manual method
# indices = []
# for i, p in enumerate(x):
# if equal_method(p, y).all():
# indices.append(i)
# if early_stopping:
# break
# return indices
def
index
(
self
,
tensor
,
**
kwargs
):
"""
...
...
@@ -972,12 +977,8 @@ class Tensors(AbstractNdarray):
tensor
Returns:
int: index of tensor occuring
Raises:
ValueError: Multiple occurences
use early_stopping=True if first entry should be returned
"""
indices
=
self
.
indices
(
tensor
,
**
kwargs
)
print
(
indices
)
if
not
indices
:
return
None
if
len
(
indices
)
==
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment