Commit 071c85db authored by Daniel Boeckenhoff's avatar Daniel Boeckenhoff
Browse files

passing rtol to base_vectors in grdi

parent 8bf5862c
...@@ -934,25 +934,36 @@ class Tensors(AbstractNdarray): ...@@ -934,25 +934,36 @@ class Tensors(AbstractNdarray):
""" """
return any(self.equal(other, return_bool=False).all(1)) return any(self.equal(other, return_bool=False).all(1))
def indices(self, tensor): def indices(self, tensor, rtol=None, atol=None, early_stopping=False):
""" """
Returns: Returns:
list of int: indices of tensor occuring list of int: indices of tensor occuring
""" """
x, y = np.asarray(self), np.asarray(tensor)
if rtol is None and atol is None:
equal_method = np.array_equal
else:
equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol)
indices = [] indices = []
for i, p in enumerate(self): for i, p in enumerate(x):
if all(p == tensor): if equal_method(p, y).all():
indices.append(i) indices.append(i)
if early_stopping:
break
return indices return indices
def index(self, tensor): def index(self, tensor, **kwargs):
""" """
Args: Args:
tensor tensor
Returns: Returns:
int: index of tensor occuring int: index of tensor occuring
Raises:
ValueError: Multiple occurences
use early_stopping=True if first entry should be returned
""" """
indices = self.indices(tensor) indices = self.indices(tensor, **kwargs)
print(indices)
if not indices: if not indices:
return None return None
if len(indices) == 1: if len(indices) == 1:
......
...@@ -123,7 +123,7 @@ def igrid(*base_vectors, **kwargs): ...@@ -123,7 +123,7 @@ def igrid(*base_vectors, **kwargs):
return obj return obj
def base_vectors(array): def base_vectors(array, rtol=None, atol=None):
""" """
describe the array in terms of base vectors describe the array in terms of base vectors
Inverse function of igrid Inverse function of igrid
...@@ -141,6 +141,17 @@ def base_vectors(array): ...@@ -141,6 +141,17 @@ def base_vectors(array):
""" """
if len(array.shape) == 1: if len(array.shape) == 1:
values = sorted(set(array)) values = sorted(set(array))
if rtol is not None and atol is not None:
duplicates = set([])
for v1, v2 in tfields.lib.util.pairwise(values):
if np.isclose(v1, v2, rtol=rtol, atol=atol):
duplicates.add(v2)
values = values.difference(duplicates)
# round to given absolute precision
n_digits = int(abs(np.log10(atol))) + 1
values = {round(v, n_digits) for v in values}
elif rtol is not None or atol is not None:
raise ValueError("rtol and atol arguments only come in pairs.")
spacing = complex(0, len(values)) spacing = complex(0, len(values))
vmin = min(values) vmin = min(values)
vmax = max(values) vmax = max(values)
...@@ -148,7 +159,7 @@ def base_vectors(array): ...@@ -148,7 +159,7 @@ def base_vectors(array):
elif len(array.shape) == 2: elif len(array.shape) == 2:
bases = [] bases = []
for i in range(array.shape[1]): for i in range(array.shape[1]):
bases.append(base_vectors(array[:, i])) bases.append(base_vectors(array[:, i], rtol=rtol, atol=atol))
return bases return bases
else: else:
raise NotImplementedError("Description yet only till rank 1") raise NotImplementedError("Description yet only till rank 1")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment