From 071c85db9e4fd96fde41f1587781b799b7d7ffb6 Mon Sep 17 00:00:00 2001
From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de>
Date: Fri, 23 Nov 2018 09:38:38 +0100
Subject: [PATCH] passing rtol to base_vectors in grdi

---
 tfields/core.py     | 21 ++++++++++++++++-----
 tfields/lib/grid.py | 15 +++++++++++++--
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/tfields/core.py b/tfields/core.py
index ce7a441..92e8d7d 100644
--- a/tfields/core.py
+++ b/tfields/core.py
@@ -934,25 +934,36 @@ class Tensors(AbstractNdarray):
         """
         return any(self.equal(other, return_bool=False).all(1))
 
-    def indices(self, tensor):
+    def indices(self, tensor, rtol=None, atol=None, early_stopping=False):
         """
         Returns:
             list of int: indices of tensor occuring
         """
+        x, y = np.asarray(self), np.asarray(tensor)
+        if rtol is None and atol is None:
+            equal_method = np.array_equal
+        else:
+            equal_method = lambda a, b: np.isclose(a, b, rtol=rtol, atol=atol)
         indices = []
-        for i, p in enumerate(self):
-            if all(p == tensor):
+        for i, p in enumerate(x):
+            if equal_method(p, y).all():
                 indices.append(i)
+                if early_stopping:
+                    break
         return indices
 
-    def index(self, tensor):
+    def index(self, tensor, **kwargs):
         """
         Args:
             tensor
         Returns:
             int: index of tensor occuring
+        Raises:
+            ValueError: Multiple occurences
+                use early_stopping=True if first entry should be returned
         """
-        indices = self.indices(tensor)
+        indices = self.indices(tensor, **kwargs)
+        print(indices)
         if not indices:
             return None
         if len(indices) == 1:
diff --git a/tfields/lib/grid.py b/tfields/lib/grid.py
index e3ff95e..d9f886d 100644
--- a/tfields/lib/grid.py
+++ b/tfields/lib/grid.py
@@ -123,7 +123,7 @@ def igrid(*base_vectors, **kwargs):
     return obj
 
 
-def base_vectors(array):
+def base_vectors(array, rtol=None, atol=None):
     """
     describe the array in terms of base vectors
     Inverse function of igrid
@@ -141,6 +141,17 @@ def base_vectors(array):
     """
     if len(array.shape) == 1:
         values = sorted(set(array))
+        if rtol is not None and atol is not None:
+            duplicates = set([])
+            for v1, v2 in tfields.lib.util.pairwise(values):
+                if np.isclose(v1, v2, rtol=rtol, atol=atol):
+                    duplicates.add(v2)
+            values = values.difference(duplicates)
+            # round to given absolute precision
+            n_digits = int(abs(np.log10(atol))) + 1
+            values = {round(v, n_digits) for v in values}
+        elif rtol is not None or atol is not None:
+            raise ValueError("rtol and atol arguments only come in pairs.")
         spacing = complex(0, len(values))
         vmin = min(values)
         vmax = max(values)
@@ -148,7 +159,7 @@ def base_vectors(array):
     elif len(array.shape) == 2:
         bases = []
         for i in range(array.shape[1]):
-            bases.append(base_vectors(array[:, i]))
+            bases.append(base_vectors(array[:, i], rtol=rtol, atol=atol))
         return bases
     else:
         raise NotImplementedError("Description yet only till rank 1")
-- 
GitLab