Commit 071c85db authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

passing rtol to base_vectors in grdi

parent 8bf5862c
......@@ -934,25 +934,36 @@ class Tensors(AbstractNdarray):
"""
return any(self.equal(other, return_bool=False).all(1))
def indices(self, tensor):
def indices(self, tensor, rtol=None, atol=None, early_stopping=False):
"""
Returns:
list of int: indices of tensor occuring
"""
x, y = np.asarray(self), np.asarray(tensor)
if rtol is None and atol is None:
equal_method = np.array_equal
else:
equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol)
indices = []
for i, p in enumerate(self):
if all(p == tensor):
for i, p in enumerate(x):
if equal_method(p, y).all():
indices.append(i)
if early_stopping:
break
return indices
def index(self, tensor):
def index(self, tensor, **kwargs):
"""
Args:
tensor
Returns:
int: index of tensor occuring
Raises:
ValueError: Multiple occurences
use early_stopping=True if first entry should be returned
"""
indices = self.indices(tensor)
indices = self.indices(tensor, **kwargs)
print(indices)
if not indices:
return None
if len(indices) == 1:
......
......@@ -123,7 +123,7 @@ def igrid(*base_vectors, **kwargs):
return obj
def base_vectors(array):
def base_vectors(array, rtol=None, atol=None):
"""
describe the array in terms of base vectors
Inverse function of igrid
......@@ -141,6 +141,17 @@ def base_vectors(array):
"""
if len(array.shape) == 1:
values = sorted(set(array))
if rtol is not None and atol is not None:
duplicates = set([])
for v1, v2 in tfields.lib.util.pairwise(values):
if np.isclose(v1, v2, rtol=rtol, atol=atol):
duplicates.add(v2)
values = values.difference(duplicates)
# round to given absolute precision
n_digits = int(abs(np.log10(atol))) + 1
values = {round(v, n_digits) for v in values}
elif rtol is not None or atol is not None:
raise ValueError("rtol and atol arguments only come in pairs.")
spacing = complex(0, len(values))
vmin = min(values)
vmax = max(values)
......@@ -148,7 +159,7 @@ def base_vectors(array):
elif len(array.shape) == 2:
bases = []
for i in range(array.shape[1]):
bases.append(base_vectors(array[:, i]))
bases.append(base_vectors(array[:, i], rtol=rtol, atol=atol))
return bases
else:
raise NotImplementedError("Description yet only till rank 1")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment