diff --git a/tests/test_bounding_box.py b/tests/test_bounding_box.py
index 453a60910bc863d61b1d1bc91bb72b55d3a84d3b..7bbdfa8e24745ceaecb55a16c1114e0de60bf04e 100644
--- a/tests/test_bounding_box.py
+++ b/tests/test_bounding_box.py
@@ -1,23 +1,17 @@
 import unittest
 import tfields
-# import numpy as np
+import numpy as np
 
 
 class BoundingBox_Test(unittest.TestCase):
     def setUp(self):
-        self._mesh = tfields.Mesh3D.grid((5.6, 6.2, 3),
-                                         (-0.25, 0.25, 4),
-                                         (-1, 1, 10))
+        self.mesh = tfields.Mesh3D.grid((5.6, 6.2, 3), (-0.25, 0.25, 4), (-1, 1, 10))
 
-        self._cuts = {'x': [5.7, 6.1],
-                      'y': [-0.2, 0, 0.2],
-                      'z': [-0.5, 0.5]}
+        self.cuts = {"x": [5.7, 6.1], "y": [-0.2, 0, 0.2], "z": [-0.5, 0.5]}
 
     def test_tree(self):
         # test already in doctests.
-        tree = tfields.bounding_box.Node(self._mesh,
-                                         self._cuts,
-                                         at_intersection='keep')
+        tree = tfields.bounding_box.Node(self.mesh, self.cuts, at_intersection="keep")
         leaves = tree.leaves()
         leaves = tfields.bounding_box.Node.sort_leaves(leaves)
         meshes = [leaf.mesh for leaf in leaves]  # NOQA
@@ -25,21 +19,39 @@ class BoundingBox_Test(unittest.TestCase):
         special_leaf = tree.find_leaf([5.65, -0.21, 0])  # NOQA
 
 
-class Searcher_Test(unittest.TestCase):
+class Searcher_Check(object):
     def setUp(self):
-        self._mesh = tfields.Mesh3D.grid((0, 1, 2), (1, 2, 2), (2, 3, 2))
+        self.mesh = None
+        self.points = None
+        self.delta = None
+        self.n_sections = None
 
-    # not yet working again
-    # def test_tree(self):
-    #     tree = tfields.bounding_box.Searcher(self._mesh, n_sections=[5, 5, 5])
-    #     points = tfields.Tensors([[0.5, 1, 2.1],
-    #                               [0.5, 0, 0],
-    #                               [0.5, 2, 2.1],
-    #                               [0.5, 1.5, 2.5]])
-    #     box_res = tree.in_faces(points, delta=0.0001)
-    #     usual_res = self._mesh.in_faces(points, delta=0.0001)
-    #     self.assertTrue(np.array_equal(box_res, usual_res))
+    @property
+    def usual_res(self):
+        if not hasattr(self, "_usual_res"):
+            self._usual_res = self.mesh.in_faces(self.points, delta=self.delta)
+        return self._usual_res
+
+    def test_tree(self):
+        tree = tfields.bounding_box.Searcher(self.mesh, n_sections=None)
+        box_res = tree.in_faces(self.points, delta=self.delta)
+        self.assertTrue(np.array_equal(box_res, self.usual_res))
+
+    def test_tree_n_sections(self):
+        tree = tfields.bounding_box.Searcher(self.mesh, n_sections=self.n_sections)
+        box_res = tree.in_faces(self.points, delta=self.delta)
+        self.assertTrue(np.array_equal(box_res, self.usual_res))
+
+
+class Searcher_Coarse_Grid_2D_Test(Searcher_Check, unittest.TestCase):
+    def setUp(self):
+        self.mesh = tfields.Mesh3D.grid((0, 1, 2), (1, 2, 2), (2, 3, 2))
+        self.points = tfields.Tensors(
+            [[0.5, 1, 2.1], [0.5, 0, 0], [0.5, 2, 2.1], [0.5, 1.5, 2.5]]
+        )
+        self.delta = 0.0001
+        self.n_sections = [5, 5, 5]
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_mesh3D.py b/tests/test_mesh3D.py
index 23570faececec9846bdfd96b72e8439297d09fbb..bc02bc6f3a16f16579df2635d5eac05572e2db8b 100644
--- a/tests/test_mesh3D.py
+++ b/tests/test_mesh3D.py
@@ -6,15 +6,17 @@ import sympy  # NOQA: F401
 import os
 import sys
 from tests.test_core import TensorMaps_Check
+
 THIS_DIR = os.path.dirname(
-    os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
+    os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
+)
 sys.path.append(os.path.normpath(os.path.join(THIS_DIR)))
 
 
 class Mesh3D_Check(TensorMaps_Check):
     def test_cut_split(self):
-        x, y, z = sympy.symbols('x y z')
-        self._inst.cut(x + 1./100*y > 0, at_intersection='split')
+        x, y, z = sympy.symbols("x y z")
+        self._inst.cut(x + 1.0 / 100 * y > 0, at_intersection="split")
 
     def test_triangle(self):
         tri = self._inst.triangles()
@@ -24,28 +26,48 @@ class Mesh3D_Check(TensorMaps_Check):
         self.assertTrue(tri.equal(tri_2))
 
     def test_save_obj(self):
-        out_file = NamedTemporaryFile(suffix='.obj')
+        out_file = NamedTemporaryFile(suffix=".obj")
         self._inst.save(out_file.name)
         _ = out_file.seek(0)  # this is only necessary in the test
         load_inst = type(self._inst).load(out_file.name)
-        print(self._inst, load_inst)
         self.demand_equal(load_inst)
 
 
-class Square_Test(Mesh3D_Check, unittest.TestCase):
+class Square_Sparse_Test(Mesh3D_Check, unittest.TestCase):
     def setUp(self):
         self._inst = tfields.Mesh3D.plane((0, 1, 2j), (0, 1, 2j), (0, 0, 1j))
 
+    def test_cut_keep(self):
+        x = sympy.symbols("x")
+        cut_inst = self._inst.cut(x < 0.5, at_intersection="keep")
+        self.assertTrue(self._inst.equal(cut_inst))
+        cut_inst = self._inst.cut(x > 0.5, at_intersection="keep")
+        self.assertTrue(self._inst.equal(cut_inst))
+
+
+class Square_Dense_Test(Mesh3D_Check, unittest.TestCase):
+    def setUp(self):
+        self._inst = tfields.Mesh3D.plane((0, 1, 5j), (0, 1, 5j), (0, 0, 1j))
+
+    def test_cut_keep(self):
+        x = sympy.symbols("x")
+        cut_inst = self._inst.cut(x < 0.9, at_intersection="keep")
+        self.assertTrue(self._inst.equal(cut_inst))
+        self.assertEqual(len(cut_inst), 25)
+        cut_inst = self._inst.cut(x > 0.9, at_intersection="keep")
+        self.assertEqual(len(cut_inst), 10)
+
 
 class Sphere_Test(Mesh3D_Check, unittest.TestCase):
     def setUp(self):
         basis_points = 4
         self._inst = tfields.Mesh3D.grid(
-                (1, 1, 1),
-                (-np.pi, np.pi, basis_points),
-                (-np.pi / 2, np.pi / 2, basis_points),
-                coord_sys='spherical')
-        self._inst.transform('cartesian')
+            (1, 1, 1),
+            (-np.pi, np.pi, basis_points),
+            (-np.pi / 2, np.pi / 2, basis_points),
+            coord_sys="spherical",
+        )
+        self._inst.transform("cartesian")
         self._inst[:, 1] += 2
         clean = self._inst.cleaned()
         # self.demand_equal(clean)
@@ -54,9 +76,17 @@ class Sphere_Test(Mesh3D_Check, unittest.TestCase):
 
 class IO_Stl_test(unittest.TestCase):  # no Mesh3D_Check for speed
     def setUp(self):
-        self._inst = tfields.Mesh3D.load(os.path.join(THIS_DIR,
-                                                      '../data/baffle.stl'))
+        self._inst = tfields.Mesh3D.load(os.path.join(THIS_DIR, "../data/baffle.stl"))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
+    # TODO: debugging snippet below never runs, since unittest.main() exits first
+    import rna
+    from sympy.abc import x
+
+    kwargs = dict(dim=3, edgecolor="k")
+    mesh = tfields.Mesh3D.grid((0, 1, 2), (1, 2, 2), (2, 3, 2))
+    cuts = mesh.cut(x < 0.2, at_intersection="split")
+    cuts.plot(**kwargs)
+    rna.plotting.show()
diff --git a/tfields/bounding_box.py b/tfields/bounding_box.py
index 83bccfb0517bae46484ccc209cc9d35fbb90bb25..b3d0145d31ce7b4f8e2f9817c7e46eb803221670 100644
--- a/tfields/bounding_box.py
+++ b/tfields/bounding_box.py
@@ -17,6 +17,14 @@ class Node(object):
         cut_expr: Cut that determines the seperation in left and right node
         cuts: List of cuts for the children nodes
 
+    Attributes:
+        parent (Node)
+        remaining_cuts (dict): key specifies the dimension, value the cuts
+            that still have to be performed
+        cut_expr (dict): part of the parent's remaining_cuts. The dimension
+            defines what is meant by left and right
+
+
     Examples:
         >>> import tfields
         >>> mesh = tfields.Mesh3D.grid((5.6, 6.2, 3),
@@ -37,27 +45,35 @@ class Node(object):
         >>> special_leaf = tree.find_leaf([5.65, -0.21, 0])
 
     """
-    def __init__(self, mesh, cuts,
-                 coord_sys=None, at_intersection="split", delta=0.,
-                 parent=None, box=None, internal_template=None, cut_expr=None):
+
+    def __init__(
+        self,
+        mesh,
+        cuts,
+        coord_sys=None,
+        at_intersection="split",
+        delta=0.0,
+        parent=None,
+        box=None,
+        internal_template=None,
+        cut_expr=None,
+    ):
         self.parent = parent
         # initialize
         self.mesh = copy.deepcopy(mesh)
         if self.is_root():
             cuts = copy.deepcopy(cuts)  # dicts are mutable
         self.remaining_cuts = cuts
-        log = logging.getLogger('Node')
-        log.debug(cuts)
+        logging.debug(cuts)
 
         self.delta = delta
         if box is None:
             vertices = np.array(self.mesh)
-            self.box = {'x': [min(vertices[:, 0]) - delta,
-                              max(vertices[:, 0]) + delta],
-                        'y': [min(vertices[:, 1]) - delta,
-                              max(vertices[:, 1]) + delta],
-                        'z': [min(vertices[:, 2]) - delta,
-                              max(vertices[:, 2]) + delta]}
+            self.box = {
+                "x": [min(vertices[:, 0]) - delta, max(vertices[:, 0]) + delta],
+                "y": [min(vertices[:, 1]) - delta, max(vertices[:, 1]) + delta],
+                "z": [min(vertices[:, 2]) - delta, max(vertices[:, 2]) + delta],
+            }
         else:
             self.box = box
         self.left = None
@@ -100,7 +116,7 @@ class Node(object):
 
     def in_box(self, point):
         x, y, z = point
-        for key in ['x', 'y', 'z']:
+        for key in ["x", "y", "z"]:
             value = locals()[key]
             if value < self.box[key][0] or self.box[key][1] < value:
                 return False
@@ -117,23 +133,29 @@ class Node(object):
         """
         sorting the leaves first in x, then y, then z direction
         """
-        sorted_leaves = sorted(leaves_list,
-                               key=lambda x: (x.box['x'][1], x.box['y'][1], x.box['z'][1]))
+        sorted_leaves = sorted(
+            leaves_list, key=lambda x: (x.box["x"][1], x.box["y"][1], x.box["z"][1])
+        )
         return sorted_leaves
 
     def _trim_to_box(self):
         # 6 cuts to remove outer part of the box
-        x, y, z = sympy.symbols('x y z')
+        x, y, z = sympy.symbols("x y z")
         eps = 0.0000000001
-        x_cut = (float(self.box['x'][0] - eps) <= x) & (x <= float(self.box['x'][1] + eps))
-        y_cut = (float(self.box['y'][0] - eps) <= y) & (y <= float(self.box['y'][1] + eps))
-        z_cut = (float(self.box['z'][0] - eps) <= z) & (z <= float(self.box['z'][1] + eps))
+        x_cut = (float(self.box["x"][0] - eps) <= x) & (
+            x <= float(self.box["x"][1] + eps)
+        )
+        y_cut = (float(self.box["y"][0] - eps) <= y) & (
+            y <= float(self.box["y"][1] + eps)
+        )
+        z_cut = (float(self.box["z"][0] - eps) <= z) & (
+            z <= float(self.box["z"][1] + eps)
+        )
         section_cut = x_cut & y_cut & z_cut
 
         self.mesh, self._internal_template = self.mesh.cut(
-            section_cut,
-            at_intersection=self.at_intersection,
-            return_template=True)
+            section_cut, at_intersection=self.at_intersection, return_template=True
+        )
 
     def leaves(self):
         """
@@ -182,26 +204,28 @@ class Node(object):
 
     def _split(self):
         """
-        Split the node, if there is no cut_expr set and remaing cuts exist.
+        Split the node into two new nodes, if there is no cut_expr set and
+        remaining cuts exist.
         """
         if self.cut_expr is None and self.remaining_cuts is None:
-            raise RuntimeError("Cannot split the mesh without cut_expr and"
-                               "remaining_cuts")
+            raise RuntimeError(
+                "Cannot split the mesh without cut_expr and" "remaining_cuts"
+            )
         else:
             # create cut expression
-            x, y, z = sympy.symbols('x y z')
-            if 'x' in self.cut_expr:
-                left_cut_expression = x <= self.cut_expr['x']
-                right_cut_expression = x >= self.cut_expr['x']
-                key = 'x'
-            elif 'y' in self.cut_expr:
-                left_cut_expression = y <= self.cut_expr['y']
-                right_cut_expression = y >= self.cut_expr['y']
-                key = 'y'
-            elif 'z' in self.cut_expr:
-                left_cut_expression = z <= self.cut_expr['z']
-                right_cut_expression = z >= self.cut_expr['z']
-                key = 'z'
+            x, y, z = sympy.symbols("x y z")
+            if "x" in self.cut_expr:
+                left_cut_expression = x <= self.cut_expr["x"]
+                right_cut_expression = x >= self.cut_expr["x"]
+                key = "x"
+            elif "y" in self.cut_expr:
+                left_cut_expression = y <= self.cut_expr["y"]
+                right_cut_expression = y >= self.cut_expr["y"]
+                key = "y"
+            elif "z" in self.cut_expr:
+                left_cut_expression = z <= self.cut_expr["z"]
+                right_cut_expression = z >= self.cut_expr["z"]
+                key = "z"
             else:
                 raise KeyError()
 
@@ -227,38 +251,50 @@ class Node(object):
             left_mesh, self.left_template = self.mesh.cut(
                 left_cut_expression,
                 at_intersection=self.at_intersection,
-                return_template=True)
+                return_template=True,
+            )
             right_mesh, self.right_template = self.mesh.cut(
                 right_cut_expression,
                 at_intersection=self.at_intersection,
-                return_template=True)
+                return_template=True,
+            )
 
             # two new Nodes
-            self.left = Node(left_mesh,
-                             left_cuts,
-                             parent=self,
-                             internal_template=self.left_template, cut_expr=None,
-                             coord_sys=self.coord_sys,
-                             at_intersection=self.at_intersection,
-                             box=left_box)
-            self.right = Node(right_mesh,
-                              right_cuts,
-                              parent=self,
-                              internal_template=self.right_template, cut_expr=None,
-                              coord_sys=self.coord_sys,
-                              at_intersection=self.at_intersection,
-                              box=right_box)
+            self.left = Node(
+                left_mesh,
+                left_cuts,
+                parent=self,
+                internal_template=self.left_template,
+                cut_expr=None,
+                coord_sys=self.coord_sys,
+                at_intersection=self.at_intersection,
+                box=left_box,
+            )
+            self.right = Node(
+                right_mesh,
+                right_cuts,
+                parent=self,
+                internal_template=self.right_template,
+                cut_expr=None,
+                coord_sys=self.coord_sys,
+                at_intersection=self.at_intersection,
+                box=right_box,
+            )
 
     def _choose_next_cut(self):
+        """
+        Set self.cut_expr to the median cut of the dimension with the most
+        remaining cuts and remove that cut from self.remaining_cuts.
+        """
         largest = 0
         for key in self.remaining_cuts:
             if len(self.remaining_cuts[key]) > largest:
                 largest = len(self.remaining_cuts[key])
                 largest_key = key
 
-        median = sorted(
-            self.remaining_cuts[largest_key]
-        )[int(0.5 * (len(self.remaining_cuts[largest_key]) - 1))]
+        median = sorted(self.remaining_cuts[largest_key])[
+            int(0.5 * (len(self.remaining_cuts[largest_key]) - 1))
+        ]
         # pop median cut from remaining cuts
         self.remaining_cuts[largest_key] = [
             x for x in self.remaining_cuts[largest_key] if x != median
@@ -323,22 +359,21 @@ class Node(object):
             if template.fields:
                 for idx in template.fields[0]:
                     template_field.append(self._convert_field_index(idx))
-                template.fields = [tfields.Tensors(template_field, dim=1,
-                                                   dtype=int)]
+                template.fields = [tfields.Tensors(template_field, dim=1, dtype=int)]
 
             template_map_field = []
             if len(template.maps[3]) > 0:
                 for idx in template.maps[3].fields[0]:
                     template_map_field.append(self._convert_map_index(idx))
-                template.maps[3].fields = [tfields.Tensors(template_map_field,
-                                                           dim=1, dtype=int)]
+                template.maps[3].fields = [
+                    tfields.Tensors(template_map_field, dim=1, dtype=int)
+                ]
             self._template = template
         return self._template
 
 
 class Searcher(Node):
-    def __init__(self, mesh, n_sections=None, delta=0.,
-                 cut_length=None):
+    def __init__(self, mesh, n_sections=None, delta=0.0, cut_length=None):
         """
         Special cutting tree root node.
         Provides a fast point in mesh search algorithm (Searcher.in_faces)
@@ -363,8 +398,9 @@ class Searcher(Node):
 
                 ab, ac = triangles.edges()
                 bc = ac - ab
-                side_lengths = np.concatenate([np.linalg.norm(side, axis=1)
-                                               for side in [ab, ac, bc]])
+                side_lengths = np.concatenate(
+                    [np.linalg.norm(side, axis=1) for side in [ab, ac, bc]]
+                )
                 # import mplTools as mpl
                 # axis= tfields.plotting.gca(2)
                 # mpl.plotHistogram(side_lengths, axis=axis)
@@ -384,18 +420,17 @@ class Searcher(Node):
         elif cut_length is not None:
             raise ValueError("cut_length not used.")
 
+        # build dictionary with cuts per dimension
         cut = {}
-        for i, key in enumerate(['x', 'y', 'z']):
+        for i, key in enumerate(["x", "y", "z"]):
             n_cuts = n_sections[i] + 1
-            if n_cuts <= 2:
-                values = []
-            else:
-                values = np.linspace(minima[i], maxima[i], n_cuts)[1:-1]
+            # [1:-1] because no need to cut at min or max
+            values = np.linspace(minima[i], maxima[i], n_cuts)[1:-1]
             cut[key] = values
 
-        return super(Searcher, self).__init__(mesh, cut,
-                                              at_intersection='keep',
-                                              delta=delta)
+        return super(Searcher, self).__init__(
+            mesh, cut, at_intersection="keep", delta=delta
+        )
 
     def in_faces(self, tensors, delta=-1, assign_multiple=False):
         """
@@ -403,25 +438,26 @@ class Searcher(Node):
             * check rare case of point+-delta outside box
 
         Examples:
-            # >>> import tfields
-            # >>> import numpy as np
-            # >>> mesh = tfields.Mesh3D.grid((0, 1, 2), (1, 2, 2), (2, 3, 2))
-            # >>> tree = tfields.bounding_box.Searcher(mesh, n_sections=[5, 5, 5])
-            # >>> points = tfields.Tensors([[0.5, 1, 2.1],
-            # ...                           [0.5, 0, 0],
-            # ...                           [0.5, 2, 2.1],
-            # ...                           [0.5, 1.5, 2.5]])
-            # >>> box_res = tree.in_faces(points, delta=0.0001)
-            # >>> usual_res = mesh.in_faces(points, delta=0.0001)
-            # >>> assert np.array_equal(box_res, usual_res)
+            >>> import tfields
+            >>> import numpy as np
+            >>> mesh = tfields.Mesh3D.grid((0, 1, 2), (1, 2, 2), (2, 3, 2))
+            >>> tree = tfields.bounding_box.Searcher(mesh)
+            >>> points = tfields.Tensors([[0.5, 1, 2.1],
+            ...                           [0.5, 0, 0],
+            ...                           [0.5, 2, 2.1],
+            ...                           [0.5, 1.5, 2.5]])
+            >>> box_res = tree.in_faces(points, delta=0.0001)
+            >>> usual_res = mesh.in_faces(points, delta=0.0001)
+            >>> assert np.array_equal(box_res, usual_res)
 
         """
-        raise ValueError("Broken feature. We are working on it!")
+        # raise ValueError("Broken feature. We are working on it!")
         if not self.is_root():
             raise ValueError("in_faces may only be called by root Node.")
-        if self.at_intersection != 'keep':
-            raise ValueError("Intersection method must be 'keep' for in_faces"
-                             "method.")
+        if self.at_intersection != "keep":
+            raise ValueError(
+                "Intersection method must be 'keep' for in_faces" "method."
+            )
 
         if self.mesh.nfaces() == 0:
             return np.empty((tensors.shape[0], 0), dtype=bool)
@@ -436,8 +472,7 @@ class Searcher(Node):
                     continue
                 if leaf.template.nfaces() == 0:
                     continue
-                leaf_mask = leaf.template.triangles()._in_triangles(point,
-                                                                    delta)
+                leaf_mask = leaf.template.triangles()._in_triangles(point, delta)
                 original_face_indices = leaf.template.maps[3].fields[0][leaf_mask]
                 if not assign_multiple and len(original_face_indices) > 0:
                     original_face_indices = original_face_indices[:1]
@@ -445,7 +480,8 @@ class Searcher(Node):
         return masks
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import doctest
+
     # doctest.run_docstring_examples(Searcher.in_faces, globals())
     doctest.testmod()
diff --git a/tfields/core.py b/tfields/core.py
index b1463b92aa570b08bd0db5a1d63cd8db27618869..92e6caf34981e6e7720eee080d74f158cfc4854f 100644
--- a/tfields/core.py
+++ b/tfields/core.py
@@ -154,7 +154,7 @@ class AbstractObject(object):
 
         """
         content_dict = self._as_dict()
-        content_dict['tfields_version'] = tfields.__version__
+        content_dict["tfields_version"] = tfields.__version__
         np.savez(path, **content_dict)
 
     @classmethod
@@ -165,10 +165,10 @@ class AbstractObject(object):
         """
         # TODO: think about allow_pickle, wheter it really should be True or
         # wheter we could avoid pickling (potential security issue)
-        load_kwargs.setdefault('allow_pickle', True)
+        load_kwargs.setdefault("allow_pickle", True)
         np_file = np.load(path, **load_kwargs)
         d = dict(np_file)
-        d.pop('tfields_version', None)
+        d.pop("tfields_version", None)
         return cls._from_dict(d)
 
     def _args(self) -> tuple:
@@ -177,7 +177,7 @@ class AbstractObject(object):
     def _kwargs(self) -> dict:
         return dict()
 
-    _HIERARCHY_SEPARATOR = '::'
+    _HIERARCHY_SEPARATOR = "::"
 
     def _as_dict(self):
         d = {}
@@ -187,17 +187,15 @@ class AbstractObject(object):
 
         # args and kwargs
         for base_attr, iterable in [
-                ('args', ((str(i), arg)
-                          for i, arg in enumerate(self._args()))),
-                ('kwargs', self._kwargs().items())]:
+            ("args", ((str(i), arg) for i, arg in enumerate(self._args()))),
+            ("kwargs", self._kwargs().items()),
+        ]:
             for attr, value in iterable:
                 attr = base_attr + self._HIERARCHY_SEPARATOR + attr
-                if hasattr(value, '_as_dict'):
+                if hasattr(value, "_as_dict"):
                     part_dict = value._as_dict()
                     for part_attr, part_value in part_dict.items():
-                        d[
-                            attr + self._HIERARCHY_SEPARATOR + part_attr
-                        ] = part_value
+                        d[attr + self._HIERARCHY_SEPARATOR + part_attr] = part_value
                 else:
                     d[attr] = value
         return d
@@ -205,7 +203,7 @@ class AbstractObject(object):
     @classmethod
     def _from_dict(cls, d: dict):
         try:
-            d.pop('type')
+            d.pop("type")
         except KeyError:
             # legacy
             return cls._from_dict_legacy(**d)
@@ -227,7 +225,7 @@ class AbstractObject(object):
         """
         for attr in here:
             for key in here[attr]:
-                if 'type' in here[attr][key]:
+                if "type" in here[attr][key]:
                     obj_type = here[attr][key].get("type")
                     if isinstance(obj_type, np.ndarray):  # happens on np.load
                         obj_type = obj_type.tolist()
@@ -238,15 +236,15 @@ class AbstractObject(object):
                     obj_type = getattr(tfields, obj_type)
                     attr_value = obj_type._from_dict(here[attr][key])
                 else:  # if len(here[attr][key]) == 1:
-                    attr_value = here[attr][key].pop('')
+                    attr_value = here[attr][key].pop("")
                 here[attr][key] = attr_value
 
-        '''
+        """
         Build the generic way
-        '''
-        args = here.pop('args', tuple())
+        """
+        args = here.pop("args", tuple())
         args = tuple(args[key] for key in sorted(args))
-        kwargs = here.pop('kwargs', {})
+        kwargs = here.pop("kwargs", {})
         assert len(here) == 0
         obj = cls(*args, **kwargs)
         return obj
@@ -294,17 +292,17 @@ class AbstractObject(object):
                 bulk_type = getattr(tfields, bulk_type)
                 list_dict[key].append(bulk_type._from_dict_legacy(**sub_dict[index]))
 
-        with cls._bypass_setters('fields', demand_existence=False):
-            '''
+        with cls._bypass_setters("fields", demand_existence=False):
+            """
             Build the normal way
-            '''
-            bulk = kwargs.pop('bulk')
-            bulk_type = kwargs.pop('bulk_type')
+            """
+            bulk = kwargs.pop("bulk")
+            bulk_type = kwargs.pop("bulk_type")
             obj = cls.__new__(cls, bulk, **kwargs)
 
-            '''
+            """
             Set list attributes
-            '''
+            """
             for attr, list_value in list_dict.items():
                 setattr(obj, attr, list_value)
         return obj
@@ -374,9 +372,7 @@ class AbstractNdarray(np.ndarray, AbstractObject):
         slot_dtypes = cls.__slot_dtypes__ + [None] * (
             len(cls.__slots__) - len(cls.__slot_dtypes__)
         )
-        for attr, default, dtype in zip(
-            cls.__slots__, slot_defaults, slot_dtypes
-        ):
+        for attr, default, dtype in zip(cls.__slots__, slot_defaults, slot_dtypes):
             if attr == "_cache":
                 continue
             if attr not in kwargs:
@@ -460,7 +456,7 @@ class AbstractNdarray(np.ndarray, AbstractObject):
         """
         # Call the parent's __setstate__ with the other tuple elements.
         super(AbstractNdarray, self).__setstate__(
-            state[0:-len(self._iter_slots())]
+            state[0 : -len(self._iter_slots())]
         )
 
         # set the __slot__ attributes
@@ -470,17 +466,18 @@ class AbstractNdarray(np.ndarray, AbstractObject):
         full information and thus need to be excluded from the __setstate__
         need to be in the same order as they have been added to __slots__
         """
-        added_slot_attrs = ['name']
+        added_slot_attrs = ["name"]
         n_np = 5  # number of numpy array states
         n_old = len(valid_slot_attrs) - len(state[n_np:])
         if n_old > 0:
             for latest_index in range(n_old):
                 new_slot = added_slot_attrs[-latest_index]
-                warnings.warn("Slots with names '{new_slot}' appears to have "
-                              "been added after the creation of the reduced "
-                              "state. No corresponding state found in "
-                              "__setstate__."
-                              .format(**locals()))
+                warnings.warn(
+                    "Slots with names '{new_slot}' appears to have "
+                    "been added after the creation of the reduced "
+                    "state. No corresponding state found in "
+                    "__setstate__.".format(**locals())
+                )
                 valid_slot_attrs.pop(valid_slot_attrs.index(new_slot))
                 setattr(self, new_slot, None)
 
@@ -498,9 +495,7 @@ class AbstractNdarray(np.ndarray, AbstractObject):
 
     @classmethod
     @contextmanager
-    def _bypass_setters(cls, *slots,
-                        empty_means_all=True,
-                        demand_existence=False):
+    def _bypass_setters(cls, *slots, empty_means_all=True, demand_existence=False):
         """
         Temporarily remove the setter in __slot_setters__ corresponding to slot
         position in __slot__. You should know what you do, when using this.
@@ -517,13 +512,11 @@ class AbstractNdarray(np.ndarray, AbstractObject):
         slot_indices = []
         setters = []
         for slot in slots:
-            slot_index = cls.__slots__.index(slot)\
-                if slot in cls.__slots__ else None
+            slot_index = cls.__slots__.index(slot) if slot in cls.__slots__ else None
             if slot_index is None:
                 # slot not in cls.__slots__.
                 if demand_existence:
-                    raise ValueError(
-                        "Slot {slot} not existing".format(**locals()))
+                    raise ValueError("Slot {slot} not existing".format(**locals()))
                 continue
             if len(cls.__slot_setters__) < slot_index + 1:
                 # no setter to be found
@@ -649,8 +642,9 @@ class Tensors(AbstractNdarray):
         Tensors([], shape=(0, 7), dtype=float64)
 
     """
-    __slots__ = ['coord_sys', 'name']
-    __slot_defaults__ = ['cartesian']
+
+    __slots__ = ["coord_sys", "name"]
+    __slot_defaults__ = ["cartesian"]
     __slot_setters__ = [tfields.bases.get_coord_system_name]
 
     def __new__(cls, tensors, **kwargs):
@@ -665,8 +659,8 @@ class Tensors(AbstractNdarray):
             coord_sys = kwargs.pop("coord_sys", tensors.coord_sys)
             tensors = tensors.copy()
             tensors.transform(coord_sys)
-            kwargs['coord_sys'] = coord_sys
-            kwargs['name'] = kwargs.pop('name', tensors.name)
+            kwargs["coord_sys"] = coord_sys
+            kwargs["name"] = kwargs.pop("name", tensors.name)
             if dtype is None:
                 dtype = tensors.dtype
         else:
@@ -681,8 +675,7 @@ class Tensors(AbstractNdarray):
             len(tensors)
         except TypeError:
             raise TypeError(
-                "Iterable structure necessary."
-                " Got {tensors}".format(**locals())
+                "Iterable structure necessary." " Got {tensors}".format(**locals())
             )
 
         """ process empty inputs """
@@ -694,12 +687,10 @@ class Tensors(AbstractNdarray):
             if issubclass(type(tensors), np.ndarray):
                 # np.empty
                 pass
-            elif hasattr(tensors, 'shape'):
+            elif hasattr(tensors, "shape"):
                 dim = dim(tensors)
             else:
-                raise ValueError(
-                    "Empty tensors need dimension parameter 'dim'."
-                )
+                raise ValueError("Empty tensors need dimension parameter 'dim'.")
 
         tensors = np.asarray(tensors, dtype=dtype, order=order)
         obj = tensors.view(cls)
@@ -831,14 +822,10 @@ class Tensors(AbstractNdarray):
                     pass
             if bases:
                 # get most frequent coord_sys
-                coord_sys = sorted(
-                    bases, key=Counter(bases).get, reverse=True
-                )[0]
+                coord_sys = sorted(bases, key=Counter(bases).get, reverse=True)[0]
                 kwargs["coord_sys"] = coord_sys
             else:
-                default = cls.__slot_defaults__[
-                    cls.__slots__.index("coord_sys")
-                ]
+                default = cls.__slot_defaults__[cls.__slots__.index("coord_sys")]
                 kwargs["coord_sys"] = default
 
         """ transform all raw inputs to cls type with correct coord_sys. Also
@@ -859,23 +846,24 @@ class Tensors(AbstractNdarray):
         for i, obj in enumerate(remainingObjects):
             tensors = np.append(tensors, obj, axis=0)
 
-        if len(tensors) == 0 and not kwargs.get('dim', None):
+        if len(tensors) == 0 and not kwargs.get("dim", None):
             # if you can not determine the tensor dimension, search for the
             # first object with some entries
-            kwargs['dim'] = dim(objects[0])
+            kwargs["dim"] = dim(objects[0])
 
         inst = cls.__new__(cls, tensors, **kwargs)
         if not return_templates:
             return inst
         else:
             tensor_lengths = [len(o) for o in objects]
-            cum_tensor_lengths = [sum(tensor_lengths[:i])
-                                  for i in range(len(objects))]
+            cum_tensor_lengths = [sum(tensor_lengths[:i]) for i in range(len(objects))]
             templates = [
                 tfields.TensorFields(
                     np.empty((len(obj), 0)),
-                    np.arange(tensor_lengths[i]) + cum_tensor_lengths[i])
-                for i, obj in enumerate(objects)]
+                    np.arange(tensor_lengths[i]) + cum_tensor_lengths[i],
+                )
+                for i, obj in enumerate(objects)
+            ]
             return inst, templates
 
     @classmethod
@@ -955,9 +943,7 @@ class Tensors(AbstractNdarray):
 
         """
         cls_kwargs = {
-            attr: kwargs.pop(attr)
-            for attr in list(kwargs)
-            if attr in cls.__slots__
+            attr: kwargs.pop(attr) for attr in list(kwargs) if attr in cls.__slots__
         }
         inst = cls.__new__(
             cls, tfields.lib.grid.igrid(*base_vectors, **kwargs), **cls_kwargs
@@ -1177,9 +1163,7 @@ class Tensors(AbstractNdarray):
                 + segment * periodicity / num_segments
             )
 
-    def equal(
-        self, other, rtol=None, atol=None, equal_nan=False, return_bool=True
-    ):
+    def equal(self, other, rtol=None, atol=None, equal_nan=False, return_bool=True):
         """
         Evaluate, whether the instance has the same content as other.
 
@@ -1190,10 +1174,7 @@ class Tensors(AbstractNdarray):
                 equal_nan (bool)
             see numpy.isclose
         """
-        if (
-            issubclass(type(other), Tensors)
-            and self.coord_sys != other.coord_sys
-        ):
+        if issubclass(type(other), Tensors) and self.coord_sys != other.coord_sys:
             other = other.copy()
             other.transform(self.coord_sys)
         x, y = np.asarray(self), np.asarray(other)
@@ -1419,13 +1400,13 @@ find-ordered-vector-in-numpy-array
         if template.fields and issubclass(type(self), TensorFields):
             template_field = np.array(template.fields[0])
             if len(self) > 0:
-                '''
+                """
                 if new vertices have been created in the template, it is
                 in principle unclear what fields we have to refer to.
                 Thus in creating the template, we gave np.nan.
                 To make it fast, we replace nan with 0 as a dummy and correct
                 the field entries afterwards with np.nan.
-                '''
+                """
                 nan_mask = np.isnan(template_field)
                 template_field[nan_mask] = 0  # dummy reference to index 0.
                 template_field = template_field.astype(int)
@@ -1527,10 +1508,7 @@ find-ordered-vector-in-numpy-array
             True
 
         """
-        if (
-            issubclass(type(other), Tensors)
-            and self.coord_sys != other.coord_sys
-        ):
+        if issubclass(type(other), Tensors) and self.coord_sys != other.coord_sys:
             other = other.copy()
             other.transform(self.coord_sys)
         return sp.spatial.distance.cdist(self, other, **kwargs)
@@ -1783,10 +1761,8 @@ class TensorFields(Tensors):
 
     """
 
-    __slots__ = ['coord_sys', 'name', 'fields']
-    __slot_setters__ = [tfields.bases.get_coord_system_name,
-                        None,
-                        as_tensors_list]
+    __slots__ = ["coord_sys", "name", "fields"]
+    __slot_setters__ = [tfields.bases.get_coord_system_name, None, as_tensors_list]
 
     def __new__(cls, tensors, *fields, **kwargs):
         rigid = kwargs.pop("rigid", True)
@@ -1806,9 +1782,7 @@ class TensorFields(Tensors):
             if not all([flen == olen for flen in field_lengths]):
                 raise ValueError(
                     "Length of base ({olen}) should be the same as"
-                    " the length of all fields ({field_lengths}).".format(
-                        **locals()
-                    )
+                    " the length of all fields ({field_lengths}).".format(**locals())
                 )
         return obj
 
@@ -1817,7 +1791,7 @@ class TensorFields(Tensors):
 
     def _kwargs(self):
         d = super()._kwargs()
-        d.pop('fields')
+        d.pop("fields")
         return d
 
     def __getitem__(self, index):
@@ -1865,8 +1839,7 @@ class TensorFields(Tensors):
                     index = index[0]
                 if item.fields:
                     # circumvent the setter here.
-                    with self._bypass_setters('fields',
-                                              demand_existence=False):
+                    with self._bypass_setters("fields", demand_existence=False):
                         item.fields = [
                             field.__getitem__(index) for field in item.fields
                         ]
@@ -1924,7 +1897,7 @@ class TensorFields(Tensors):
             )
 
         return_value = super(TensorFields, cls).merged(*objects, **kwargs)
-        return_templates = kwargs.get('return_templates', False)
+        return_templates = kwargs.get("return_templates", False)
         if return_templates:
             inst, templates = return_value
         else:
@@ -1964,8 +1937,11 @@ class TensorFields(Tensors):
     @names.setter
     def names(self, names):
         if not len(names) == len(self.fields):
-            raise ValueError("len(names) ({0}) != len(fields) ({1})"
-                             .format(len(names), len(self.fields)))
+            raise ValueError(
+                "len(names) ({0}) != len(fields) ({1})".format(
+                    len(names), len(self.fields)
+                )
+            )
         for i, name in enumerate(names):
             self.fields[i].name = name
 
@@ -2081,7 +2057,7 @@ class Container(Fields):
             self.append(item)
 
     def _kwargs(self):
-        return {'labels': self.labels}
+        return {"labels": self.labels}
 
 
 class Maps(sortedcontainers.SortedDict, AbstractObject):
@@ -2098,6 +2074,7 @@ class Maps(sortedcontainers.SortedDict, AbstractObject):
         )
 
     """
+
     def __init__(self, *args, **kwargs):
         if args and args[0] is None:
             # None key passed e.g. by copy. We do not change keys here.
@@ -2150,7 +2127,7 @@ class Maps(sortedcontainers.SortedDict, AbstractObject):
             else:
                 copy = True
         if copy:  # not else, because in case of wrong mp type we initialize
-            kwargs.setdefault('dtype', int)
+            kwargs.setdefault("dtype", int)
             mp = TensorFields(mp, *fields, **kwargs)
         return mp
 
@@ -2208,12 +2185,14 @@ class TensorMaps(TensorFields):
         >>> assert mesh_cp_cyl.coord_sys == tfields.bases.CYLINDER
 
     """
-    __slots__ = ['coord_sys', 'name', 'fields', 'maps']
-    __slot_setters__ = [tfields.bases.get_coord_system_name,
-                        None,
-                        as_tensors_list,
-                        as_maps,
-                        ]
+
+    __slots__ = ["coord_sys", "name", "fields", "maps"]
+    __slot_setters__ = [
+        tfields.bases.get_coord_system_name,
+        None,
+        as_tensors_list,
+        as_maps,
+    ]
 
     def __new__(cls, tensors, *fields, **kwargs):
         if issubclass(type(tensors), TensorMaps):
@@ -2300,15 +2279,13 @@ class TensorMaps(TensorFields):
                     map_mask = ~map_delete_mask
 
                     # build the correction counters
-                    move_up_counter = np.zeros(
-                        self.maps[map_dim].shape, dtype=int
-                    )
+                    move_up_counter = np.zeros(self.maps[map_dim].shape, dtype=int)
                     for p in delete_indices:
                         move_up_counter[self.maps[map_dim] > p] -= 1
 
-                    item.maps[map_dim] = (
-                        self.maps[map_dim] + move_up_counter
-                    )[map_mask]
+                    item.maps[map_dim] = (self.maps[map_dim] + move_up_counter)[
+                        map_mask
+                    ]
 
         return item
 
@@ -2317,16 +2294,13 @@ class TensorMaps(TensorFields):
         if not all([isinstance(o, cls) for o in objects]):
             # TODO: could allow if all face_fields are none
             raise TypeError(
-                "Merge constructor only accepts {cls} instances.".format(
-                    **locals()
-                )
+                "Merge constructor only accepts {cls} instances.".format(**locals())
             )
         tensor_lengths = [len(o) for o in objects]
-        cum_tensor_lengths = [sum(tensor_lengths[:i])
-                              for i in range(len(objects))]
+        cum_tensor_lengths = [sum(tensor_lengths[:i]) for i in range(len(objects))]
 
         return_value = super().merged(*objects, **kwargs)
-        return_templates = kwargs.get('return_templates', False)
+        return_templates = kwargs.get("return_templates", False)
         if return_templates:
             inst, templates = return_value
         else:
@@ -2344,18 +2318,17 @@ class TensorMaps(TensorFields):
         template_maps_list = [[] for i in range(len(objects))]
         for dimension in sorted(dim_maps_dict):
             # sort by object index
-            dim_maps = [dim_maps_dict[dimension][i]
-                        for i in range(len(objects))]
+            dim_maps = [dim_maps_dict[dimension][i] for i in range(len(objects))]
 
             return_value = TensorFields.merged(
-                *dim_maps,
-                return_templates=return_templates,
+                *dim_maps, return_templates=return_templates,
             )
             if return_templates:
                 mp, dimension_map_templates = return_value
                 for i in range(len(objects)):
-                    template_maps_list[i].append((dimension,
-                                                  dimension_map_templates[i]))
+                    template_maps_list[i].append(
+                        (dimension, dimension_map_templates[i])
+                    )
             else:
                 mp = return_value
             maps.append(mp)
@@ -2366,8 +2339,8 @@ class TensorMaps(TensorFields):
                 # template maps will not have dimensions according to their
                 # tensors which are indices
                 templates[i] = tfields.TensorMaps(
-                    templates[i],
-                    maps=Maps(template_maps))
+                    templates[i], maps=Maps(template_maps)
+                )
             return inst, templates
         else:
             return inst
@@ -2507,31 +2480,36 @@ class TensorMaps(TensorFields):
             copy of self without stale vertices and duplicat points (depending
             on arguments)
         """
-        # remove stale vertices
+        if not stale and not duplicates:
+            inst = self.copy()
         if stale:
-            stale_mask = self.stale()
-        else:
-            stale_mask = np.full(self.shape[0], False, dtype=bool)
-        # remove duplicates in order to not have any artificial separations
-        inst = self
+            # remove stale vertices, i.e. those that are not referenced by
+            # any map
+            remove_mask = self.stale()
+            inst = self.removed(remove_mask)
         if duplicates:
-            inst = self.copy()
-            duplicates = tfields.lib.util.duplicates(self, axis=0)
-            for tensor_index, duplicate_index in zip(
-                range(self.shape[0]), duplicates
-            ):
+            # remove duplicates in order to not have any artificial separations
+            if not stale:
+                # we have not yet made a copy but want to work on inst
+                inst = self.copy()
+            remove_mask = np.full(inst.shape[0], False, dtype=bool)
+            duplicates = tfields.lib.util.duplicates(inst, axis=0)
+            for tensor_index, duplicate_index in zip(range(inst.shape[0]), duplicates):
                 if duplicate_index != tensor_index:
-                    stale_mask[tensor_index] = True
-                    # redirect maps
-                    for map_dim in self.maps:
-                        for f in range(len(self.maps[map_dim])):  # face index
-                            mp = np.array(self.maps[map_dim], dtype=int)
+                    # mark duplicate at tensor_index for removal
+                    remove_mask[tensor_index] = True
+                    # redirect maps. Note: work on inst.maps instead of
+                    # self.maps in case stale vertices were removed
+                    for map_dim in inst.maps:
+                        for f in range(len(inst.maps[map_dim])):  # face index
+                            mp = np.array(inst.maps[map_dim], dtype=int)
                             if tensor_index in mp[f]:
-                                index = tfields.lib.util.index(mp[f],
-                                                               tensor_index)
+                                index = tfields.lib.util.index(mp[f], tensor_index)
                                 inst.maps[map_dim][f][index] = duplicate_index
-
-        return inst.removed(stale_mask)
+            if remove_mask.any():
+                # prevent another copy
+                inst = inst.removed(remove_mask)
+        return inst
 
     def removed(self, remove_condition):
         """
@@ -2604,7 +2582,6 @@ class TensorMaps(TensorFields):
             List(cls): One TensorMaps or TensorMaps subclass per
                 map_description
         """
-        # raise ValueError(map_descriptions)
         parts = []
         for map_description in map_descriptions:
             map_dim, map_indices_list = map_description
@@ -2739,8 +2716,7 @@ class TensorMaps(TensorFields):
         sorted_paths = []
         for path in paths:
             # find start index
-            values, counts = np.unique(path.maps[map_dim].flat,
-                                       return_counts=True)
+            values, counts = np.unique(path.maps[map_dim].flat, return_counts=True)
 
             first_node = None
             for v, c in zip(values, counts):
@@ -2789,4 +2765,5 @@ class TensorMaps(TensorFields):
 
 if __name__ == "__main__":  # pragma: no cover
     import doctest
+
     doctest.testmod()
diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py
index 03bfb6e743fc5257c6f2b367ab9e882d5b90d723..603be92afb0ec0837d18ff11692152168b03098e 100644
--- a/tfields/mesh3D.py
+++ b/tfields/mesh3D.py
@@ -18,7 +18,7 @@ import os
 
 
 def _dist_from_plane(point, plane):
-    return plane['normal'].dot(point) + plane['d']
+    return plane["normal"].dot(point) + plane["d"]
 
 
 def _segment_plane_intersection(p0, p1, plane):
@@ -92,8 +92,7 @@ def _intersect(triangle, plane, vertices_rejected):
     s2, d2 = _segment_plane_intersection(triangle[2], triangle[0], plane)
 
     single_index = index
-    couple_indices = [j for j in range(3)
-                      if not vertices_rejected[j] == lonely_bool]
+    couple_indices = [j for j in range(3) if not vertices_rejected[j] == lonely_bool]
 
     # TODO handle special cases. For now triangles with at least two points on plane are excluded
     new_points = None
@@ -110,14 +109,20 @@ def _intersect(triangle, plane, vertices_rejected):
     if lonely_bool:
         # two new triangles
         if len(s0) == 1 and len(s1) == 1:
-            new_points = [[couple_indices[0], s0[0], couple_indices[1]],
-                          [couple_indices[1], complex(1), s1[0]]]
+            new_points = [
+                [couple_indices[0], s0[0], couple_indices[1]],
+                [couple_indices[1], complex(1), s1[0]],
+            ]
         elif len(s1) == 1 and len(s2) == 1:
-            new_points = [[couple_indices[0], couple_indices[1], s1[0]],
-                          [couple_indices[0], complex(2), s2[0]]]
+            new_points = [
+                [couple_indices[0], couple_indices[1], s1[0]],
+                [couple_indices[0], complex(2), s2[0]],
+            ]
         elif len(s0) == 1 and len(s2) == 1:
-            new_points = [[couple_indices[0], couple_indices[1], s0[0]],
-                          [couple_indices[1], s2[0], complex(2)]]
+            new_points = [
+                [couple_indices[0], couple_indices[1], s0[0]],
+                [couple_indices[1], s2[0], complex(2)],
+            ]
     else:
         # one new triangle
         if len(s0) == 1 and len(s1) == 1:
@@ -191,19 +196,20 @@ class Mesh3D(tfields.TensorMaps):
         >>> assert np.array_equal(m1.faces, np.array([[0, 1, 2]]))
 
     """
+
     def __new__(cls, tensors, *fields, **kwargs):
-        kwargs['dim'] = 3
-        if 'maps' in kwargs and 'faces' in kwargs:
+        kwargs["dim"] = 3
+        if "maps" in kwargs and "faces" in kwargs:
             raise ValueError("Conflicting options maps and faces")
-        faces = kwargs.pop('faces', None)
-        maps = kwargs.pop('maps', None)
+        faces = kwargs.pop("faces", None)
+        maps = kwargs.pop("maps", None)
         if faces is not None:
             if len(faces) == 0:
                 # faces = []
                 faces = np.empty((0, 3))
             maps = [faces]
         if maps is not None:
-            kwargs['maps'] = maps
+            kwargs["maps"] = maps
         obj = super(Mesh3D, cls).__new__(cls, tensors, *fields, **kwargs)
         if len(obj.maps) > 1:
             raise ValueError("Mesh3D only allows one map")
@@ -215,28 +221,29 @@ class Mesh3D(tfields.TensorMaps):
         """
         Save obj as wavefront/.obj file
         """
-        obj = kwargs.pop('object', None)
-        group = kwargs.pop('group', None)
+        obj = kwargs.pop("object", None)
+        group = kwargs.pop("group", None)
 
-        cmap = kwargs.pop('cmap', 'viridis')
-        map_index = kwargs.pop('map_index', None)
+        cmap = kwargs.pop("cmap", "viridis")
+        map_index = kwargs.pop("map_index", None)
 
-        path = path.replace('.obj', '')
+        path = path.replace(".obj", "")
         directory, name = os.path.split(path)
 
         if map_index is not None:
             scalars = self.maps[3].fields[map_index]
             min_scalar = scalars[~np.isnan(scalars)].min()
             max_scalar = scalars[~np.isnan(scalars)].max()
-            vmin = kwargs.pop('vmin', min_scalar)
-            vmax = kwargs.pop('vmax', max_scalar)
+            vmin = kwargs.pop("vmin", min_scalar)
+            vmax = kwargs.pop("vmax", max_scalar)
             if vmin == vmax:
-                if vmin == 0.:
-                    vmax = 1.
+                if vmin == 0.0:
+                    vmax = 1.0
                 else:
-                    vmin = 0.
+                    vmin = 0.0
             import matplotlib.colors as colors
             import matplotlib.pyplot as plt
+
             norm = colors.Normalize(vmin, vmax)
             color_map = plt.get_cmap(cmap)
         else:
@@ -248,7 +255,7 @@ class Mesh3D(tfields.TensorMaps):
             raise ValueError("Unused arguments.")
 
         if norm is not None:
-            mat_name = name + '_frame_{0}.mat'.format(map_index)
+            mat_name = name + "_frame_{0}.mat".format(map_index)
             scalars[np.isnan(scalars)] = min_scalar - 1
             sorted_scalars = scalars[scalars.argsort()]
             sorted_scalars[sorted_scalars == min_scalar - 1] = np.nan
@@ -256,20 +263,21 @@ class Mesh3D(tfields.TensorMaps):
             scalar_set = np.unique(sorted_scalars)
             scalar_set[scalar_set == min_scalar - 1] = np.nan
             mat_path = os.path.join(directory, mat_name)
-            with open(mat_path, 'w') as mf:
+            with open(mat_path, "w") as mf:
                 for s in scalar_set:
                     if np.isnan(s):
                         mf.write("newmtl nan")
                         mf.write("Kd 0 0 0\n\n")
                     else:
                         mf.write("newmtl mtl_{0}\n".format(s))
-                        mf.write("Kd {c[0]} {c[1]} {c[2]}\n\n"
-                                 .format(c=color_map(norm(s))))
+                        mf.write(
+                            "Kd {c[0]} {c[1]} {c[2]}\n\n".format(c=color_map(norm(s)))
+                        )
         else:
             sorted_faces = self.faces
 
         # writing of the obj file
-        with open(path + '.obj', 'w') as f:
+        with open(path + ".obj", "w") as f:
             f.write("# File saved with tfields Mesh3D._save_obj method\n\n")
             if norm is not None:
                 f.write("mtllib ./{0}\n\n".format(mat_name))
@@ -303,48 +311,49 @@ class Mesh3D(tfields.TensorMaps):
         Given a path to a obj/wavefront file, construct the object
         """
         import csv
+
         log = logging.getLogger()
 
-        with open(path, mode='r') as f:
-            reader = csv.reader(f, delimiter=' ')
+        with open(path, mode="r") as f:
+            reader = csv.reader(f, delimiter=" ")
             groups = []
             group = None
             vertex_no = 1
             for line in reader:
                 if not line:
                     continue
-                if line[0] == '#':
+                if line[0] == "#":
                     continue
-                if line[0] == 'g':
+                if line[0] == "g":
                     if group:
                         groups.append(group)
                     group = dict(name=line[1], vertices={}, faces=[])
-                elif line[0] == 'v':
+                elif line[0] == "v":
                     if not group:
                         log.debug("No group found. Setting default 'Group'")
-                        group = dict(name='Group', vertices={}, faces=[])
+                        group = dict(name="Group", vertices={}, faces=[])
                     vertex = list(map(float, line[1:4]))
-                    group['vertices'][vertex_no] = vertex
+                    group["vertices"][vertex_no] = vertex
                     vertex_no += 1
-                elif line[0] == 'f':
+                elif line[0] == "f":
                     face = []
                     for v in line[1:]:
-                        w = v.split('/')
+                        w = v.split("/")
                         face.append(int(w[0]))
-                    group['faces'].append(face)
+                    group["faces"].append(face)
             else:
                 groups.append(group)
 
         vertices = []
         for g in groups[:]:
-            vertices.extend(g['vertices'].values())
+            vertices.extend(g["vertices"].values())
 
         if len(group_names) != 0:
-            groups = [g for g in groups if g['name'] in group_names]
+            groups = [g for g in groups if g["name"] in group_names]
 
         faces = []
         for g in groups:
-            faces.extend(g['faces'])
+            faces.extend(g["faces"])
         faces = np.add(np.array(faces), -1).tolist()
 
         """
@@ -358,8 +367,10 @@ class Mesh3D(tfields.TensorMaps):
             if length == 3:
                 continue
             if length == 4:
-                log.warning("Given a Rectangle. I will split it but "
-                            "sometimes the order is different.")
+                log.warning(
+                    "Quadrilateral face given. Splitting it into two "
+                    "triangles; the resulting vertex order may differ."
+                )
                 faces.insert(i + 1, faces[i][2:] + faces[i][:1])
                 faces[i] = faces[i][:3]
             else:
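As a concrete example of the quad handling above: a hypothetical face [0, 1, 2, 3] becomes

    faces[i]      # -> [0, 1, 2]   (faces[i][:3])
    faces[i + 1]  # -> [2, 3, 0]   (faces[i][2:] + faces[i][:1])

so the two triangles share the diagonal 0-2 and keep the original winding.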
@@ -465,8 +476,7 @@ class Mesh3D(tfields.TensorMaps):
         for ind in indices:
             for coord in coords:
                 basePart = base_vectors[:]
-                basePart[coord] = np.array([base_vectors[coord][ind]],
-                                           dtype=float)
+                basePart[coord] = np.array([base_vectors[coord][ind]], dtype=float)
                 planes.append(cls.plane(*basePart, **kwargs))
         inst = cls.merged(*planes, **kwargs)
         return inst
@@ -476,8 +486,9 @@ class Mesh3D(tfields.TensorMaps):
         if self.maps:
             return self.maps[3]
         else:
-            logging.warning("No faces found. Mesh has {x} vertices."
-                            .format(x=len(self)))
+            logging.warning(
+                "No faces found. Mesh has {x} vertices.".format(x=len(self))
+            )
             return tfields.Maps.to_map([], dim=3)
 
     @faces.setter
@@ -533,16 +544,15 @@ class Mesh3D(tfields.TensorMaps):
         self.tree or setting it to self.tree = <saved tree> before
         calling in_faces
         """
-        key = 'mesh_tree'
-        if hasattr(self, '_cache') and key in self._cache:
+        key = "mesh_tree"
+        if hasattr(self, "_cache") and key in self._cache:
             log = logging.getLogger()
-            log.info(
-                "Using cached decision tree to speed up point - face mapping.")
-            masks = self.tree.in_faces(
-                points, delta, assign_multiple=assign_multiple)
+            log.info("Using cached decision tree to speed up point-face mapping.")
+            masks = self.tree.in_faces(points, delta, assign_multiple=assign_multiple)
         else:
             masks = self.triangles().in_triangles(
-                points, delta, assign_multiple=assign_multiple)
+                points, delta, assign_multiple=assign_multiple
+            )
         return masks
 
     @property
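A minimal usage sketch of the caching described in in_faces above (saved_tree and points are placeholder names; a compatible tree has to be built elsewhere). Note that the tree getter below currently raises ValueError("Broken feature. ..."), so this cached path is effectively disabled at the moment:

    mesh.tree = saved_tree                     # stored in self._cache["mesh_tree"]
    masks = mesh.in_faces(points, delta=None)  # intended to use the cached tree

Without a cached entry, the same call falls back to mesh.triangles().in_triangles(...).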
@@ -561,9 +571,9 @@ class Mesh3D(tfields.TensorMaps):
             # >>> assert mask.sum() == 1  # one point in one triangle
         """
         raise ValueError("Broken feature. We are working on it!")
-        if not hasattr(self, '_cache'):
+        if not hasattr(self, "_cache"):
             self._cache = {}
-        key = 'mesh_tree'
+        key = "mesh_tree"
         if key in self._cache:
             tree = self._cache[key]
         else:
@@ -573,9 +583,9 @@ class Mesh3D(tfields.TensorMaps):
 
     @tree.setter
     def tree(self, tree):
-        if not hasattr(self, '_cache'):
+        if not hasattr(self, "_cache"):
             self._cache = {}
-        key = 'mesh_tree'
+        key = "mesh_tree"
         self._cache[key] = tree
 
     def remove_faces(self, face_delete_mask):
@@ -616,17 +626,19 @@ class Mesh3D(tfields.TensorMaps):
                 scalars.append(face_indices[face_mask][0])
             inst.maps[3].fields = [tfields.Tensors(scalars, dim=1)]
         else:
-            inst.maps = [tfields.TensorFields([],
-                                              tfields.Tensors([], dim=1),
-                                              dim=3,
-                                              dtype=int)
-                         ]
+            inst.maps = [
+                tfields.TensorFields([], tfields.Tensors([], dim=1), dim=3, dtype=int)
+            ]
         return inst
 
-    def project(self, tensor_field,
-                delta=None, merge_functions=None,
-                point_face_assignment=None,
-                return_point_face_assignment=False):
+    def project(
+        self,
+        tensor_field,
+        delta=None,
+        merge_functions=None,
+        point_face_assignment=None,
+        return_point_face_assignment=False,
+    ):
         """
         project the points of the tensor_field onto a copy of the mesh
         and store the values of the field, merged per face, in the maps fields.
@@ -698,7 +710,7 @@ class Mesh3D(tfields.TensorMaps):
         # setup empty map fields and collect fields
         n_faces = len(self.maps[3])
         point_indices = np.arange(len(tensor_field))
-        if not hasattr(tensor_field, 'fields') or len(tensor_field.fields) == 0:
+        if not hasattr(tensor_field, "fields") or len(tensor_field.fields) == 0:
             # if no fields exist, use int-type fields and empty_map_fields
             # in order to generate a sum
             fields = [np.full(len(tensor_field), 1, dtype=int)]
@@ -712,8 +724,7 @@ class Mesh3D(tfields.TensorMaps):
                 cls = type(field)
                 kwargs = {key: getattr(field, key) for key in cls.__slots__}
                 shape = (n_faces,) + field.shape[1:]
-                empty_map_fields.append(cls(np.full(shape, np.nan),
-                                            **kwargs))
+                empty_map_fields.append(cls(np.full(shape, np.nan), **kwargs))
             if merge_functions is None:
                 merge_functions = [lambda x: np.mean(x, axis=0)] * len(fields)
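The default above averages all point values assigned to one face. A caller could merge differently per field, e.g. (a sketch; mesh and tensor_field are placeholder names and tensor_field is assumed to carry a single field):

    projected = mesh.project(
        tensor_field,
        merge_functions=[lambda x: np.max(x, axis=0)],  # per-face maximum instead of mean
    )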
 
@@ -726,17 +737,20 @@ class Mesh3D(tfields.TensorMaps):
             mask = self.in_faces(tensor_field, delta=delta)
 
             face_indices = np.arange(n_faces)
-            point_face_assignment = [face_indices[mask[p_index, :]]
-                                     for p_index in point_indices]
-            point_face_assignment = np.array([fa if len(fa) != 0 else [-1]
-                                              for fa in point_face_assignment])
+            point_face_assignment = [
+                face_indices[mask[p_index, :]] for p_index in point_indices
+            ]
+            point_face_assignment = np.array(
+                [fa if len(fa) != 0 else [-1] for fa in point_face_assignment]
+            )
             point_face_assignment = point_face_assignment.reshape(-1)
         point_face_assignment_set = set(point_face_assignment)
 
         # merge the fields according to point_face_assignment
         map_fields = []
-        for field, map_field, merge_function in \
-                zip(fields, empty_map_fields, merge_functions):
+        for field, map_field, merge_function in zip(
+            fields, empty_map_fields, merge_functions
+        ):
             for i, f_index in enumerate(point_face_assignment_set):
                 if f_index == -1:
                     # point could not be mapped
@@ -764,10 +778,10 @@ class Mesh3D(tfields.TensorMaps):
 
         inst = self.copy()
 
-        '''
+        """
         add the indices of the vertices and maps to the fields. They will be
         removed afterwards
-        '''
+        """
         if not _in_recursion:
             inst.fields.append(tfields.Tensors(np.arange(len(inst))))
             for mp in inst.maps.values():
@@ -783,14 +797,14 @@ class Mesh3D(tfields.TensorMaps):
         elif all(~mask):
             # all vertices are valid
             inst = inst[mask]
-        elif at_intersection == 'keep':
+        elif at_intersection == "keep":
             expression_parts = tfields.lib.symbolics.split_expression(expression)
             if len(expression_parts) > 1:
                 new_mesh = inst.copy()
                 for exprPart in expression_parts:
-                    inst, _ = inst._cut_sympy(exprPart,
-                                              at_intersection=at_intersection,
-                                              _in_recursion=True)
+                    inst, _ = inst._cut_sympy(
+                        exprPart, at_intersection=at_intersection, _in_recursion=True
+                    )
             elif len(expression_parts) == 1:
                 face_delete_indices = set([])
                 for i, face in enumerate(inst.maps[3]):
@@ -810,14 +824,14 @@ class Mesh3D(tfields.TensorMaps):
             else:
                 raise ValueError("Sympy expression is not splitable.")
             inst = inst.cleaned()
-        elif at_intersection == 'split' or at_intersection == 'splitRough':
-            '''
+        elif at_intersection == "split" or at_intersection == "split_rough":
+            """
             add vertices and faces that are at the border of the cuts
-            '''
+            """
             expression_parts = tfields.lib.symbolics.split_expression(expression)
             if len(expression_parts) > 1:
                 new_mesh = inst.copy()
-                if at_intersection == 'splitRough':
+                if at_intersection == "split_rough":
                     """
                     the following is, to speed up the process. Problem is, that
                     triangles can exist, where all points lie outside the cut,
@@ -828,30 +842,37 @@ class Mesh3D(tfields.TensorMaps):
                     face_inters_mask = np.full((inst.faces.shape[0]), False, dtype=bool)
                     for i, face in enumerate(inst.faces):
                         vertices_rejected = [~mask[f] for f in face]
-                        face_on_edge = any(vertices_rejected) and not all(vertices_rejected)
+                        face_on_edge = any(vertices_rejected) and not all(
+                            vertices_rejected
+                        )
                         if face_on_edge:
                             face_inters_mask[i] = True
                     new_mesh.remove_faces(~face_inters_mask)
 
                 for exprPart in expression_parts:
-                    inst, _ = inst._cut_sympy(exprPart,
-                                              at_intersection='split',
-                                              _in_recursion=True)
+                    inst, _ = inst._cut_sympy(
+                        exprPart, at_intersection="split", _in_recursion=True
+                    )
             elif len(expression_parts) == 1:
-                points = [sympy.symbols('x0, y0, z0'),
-                          sympy.symbols('x1, y1, z1'),
-                          sympy.symbols('x2, y2, z2')]
+                points = [
+                    sympy.symbols("x0, y0, z0"),
+                    sympy.symbols("x1, y1, z1"),
+                    sympy.symbols("x2, y2, z2"),
+                ]
                 plane_sympy = tfields.lib.symbolics.to_plane(expression)
                 norm_sympy = np.array(plane_sympy.normal_vector).astype(float)
                 d = -norm_sympy.dot(np.array(plane_sympy.p1).astype(float))
-                plane = {'normal': norm_sympy, 'd': d}
+                plane = {"normal": norm_sympy, "d": d}
 
                 norm_vectors = inst.triangles().norms()
                 new_points = np.empty((0, 3))
                 new_faces = np.empty((0, 3))
-                new_fields = [tfields.Tensors(np.empty((0,) + field.shape[1:]),
-                                              coord_sys=field.coord_sys)
-                              for field in inst.fields]
+                new_fields = [
+                    tfields.Tensors(
+                        np.empty((0,) + field.shape[1:]), coord_sys=field.coord_sys
+                    )
+                    for field in inst.fields
+                ]
                 new_map_fields = [[] for field in inst.maps[3].fields]
                 new_norm_vectors = []
                 newScalarMap = []
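To make the plane dictionary above concrete: for a cut plane such as z = 0.5, sympy reports a normal_vector of (0, 0, 1) and p1 as some point on the plane, e.g. (0, 0, 0.5), so (a hedged worked example):

    norm_sympy = np.array([0.0, 0.0, 1.0])
    d = -norm_sympy.dot([0.0, 0.0, 0.5])    # -0.5
    plane = {"normal": norm_sympy, "d": d}  # encodes normal . x + d == 0, i.e. z == 0.5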
@@ -860,8 +881,7 @@ class Mesh3D(tfields.TensorMaps):
                 vertices = np.array(inst)
                 faces = np.array(inst.maps[3])
                 fields = [np.array(field) for field in inst.fields]
-                faces_fields = [np.array(field)
-                                for field in inst.maps[3].fields]
+                faces_fields = [np.array(field) for field in inst.maps[3].fields]
 
                 face_delete_indices = set([])
                 for i, face in enumerate(inst.maps[3]):
@@ -883,7 +903,8 @@ class Mesh3D(tfields.TensorMaps):
                         Add the intersection points and faces
                         """
                         intersection = _intersect(
-                            triangle_points, plane, vertices_rejected)
+                            triangle_points, plane, vertices_rejected
+                        )
                         last_idx = len(vertices) - 1
                         for tri_list in intersection:
                             new_face = []
@@ -899,28 +920,29 @@ class Mesh3D(tfields.TensorMaps):
                                     # new vertex
                                     new_face.append(len(vertices))
                                     vertices = np.append(
-                                        vertices,
-                                        [[float(x) for x in item]],
-                                        axis=0)
+                                        vertices, [[float(x) for x in item]], axis=0
+                                    )
                                     fields = [
-                                        np.append(field,
-                                                  np.full((1,) + field.shape[1:], np.nan),
-                                                  axis=0)
-                                        for field in fields]
+                                        np.append(
+                                            field,
+                                            np.full((1,) + field.shape[1:], np.nan),
+                                            axis=0,
+                                        )
+                                        for field in fields
+                                    ]
                             faces = np.append(faces, [new_face], axis=0)
-                            faces_fields = [np.append(field,
-                                                      [field[i]],
-                                                      axis=0)
-                                            for field in faces_fields]
+                            faces_fields = [
+                                np.append(field, [field[i]], axis=0)
+                                for field in faces_fields
+                            ]
                             faces_fields[-1][-1] = i
 
                 face_map = tfields.TensorFields(
-                    faces, *faces_fields,
-                    dtype=int,
-                    coord_sys=inst.maps[3].coord_sys)
-                inst = tfields.Mesh3D(vertices, *fields,
-                                      maps=[face_map],
-                                      coord_sys=inst.coord_sys)
+                    faces, *faces_fields, dtype=int, coord_sys=inst.maps[3].coord_sys
+                )
+                inst = tfields.Mesh3D(
+                    vertices, *fields, maps=[face_map], coord_sys=inst.coord_sys
+                )
                 mask = np.full(len(inst.maps[3]), True, dtype=bool)
                 for face_idx in range(len(inst.maps[3])):
                     if face_idx in face_delete_indices:
@@ -929,11 +951,13 @@ class Mesh3D(tfields.TensorMaps):
             else:
                 raise ValueError("Sympy expression is not splitable.")
             inst = inst.cleaned()
-        elif at_intersection == 'remove':
+        elif at_intersection == "remove":
             inst = inst[mask]
         else:
-            raise AttributeError("No at_intersection method called {at_intersection} "
-                                 "implemented".format(**locals()))
+            raise AttributeError(
+                "No at_intersection method called '{at_intersection}' "
+                "is implemented".format(**locals())
+            )
 
         if _in_recursion:
             template = None
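Taken together, the branches above implement at_intersection = "keep", "split" (with the rougher "split_rough" pre-filter) and "remove"; anything else raises the AttributeError. A hedged sketch, assuming the public cut method further down forwards its keyword arguments here:

    result = mesh.cut(expression, at_intersection="split")  # expression: a sympy relation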
@@ -941,12 +965,11 @@ class Mesh3D(tfields.TensorMaps):
             template_field = inst.fields.pop(-1)
             template_maps = []
             for mp in inst.maps.values():
-                t_mp = tfields.TensorFields(tfields.Tensors(mp),
-                                            mp.fields.pop(-1))
+                t_mp = tfields.TensorFields(tfields.Tensors(mp), mp.fields.pop(-1))
                 template_maps.append(t_mp)
-            template = tfields.Mesh3D(tfields.Tensors(inst),
-                                      template_field,
-                                      maps=template_maps)
+            template = tfields.Mesh3D(
+                tfields.Tensors(inst), template_field, maps=template_maps
+            )
         return inst, template
 
     def _cut_template(self, template):
@@ -990,13 +1013,13 @@ class Mesh3D(tfields.TensorMaps):
         if template.fields:
             template_field = np.array(template.fields[0])
             if len(self) > 0:
-                '''
+                """
                 if new vertices were created in the template, it is in
                 principle unclear which field entries they should refer to.
                 Therefore the template stores np.nan for them. To keep the
                 lookup fast, we replace nan with the dummy index 0, gather
                 the fields, and afterwards reset those entries to np.nan.
-                '''
+                """
                 nan_mask = np.isnan(template_field)
                 template_field[nan_mask] = 0  # dummy reference to index 0.
                 template_field = template_field.astype(int)
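A small worked example of the dummy trick (hypothetical values; numpy is imported as np in this module): for a vertex field [10., 20., 30.] and a template field [2, nan, 0],

    template_field = np.array([2.0, np.nan, 0.0])
    nan_mask = np.isnan(template_field)           # [False, True, False]
    template_field[nan_mask] = 0                  # dummy index 0
    template_field = template_field.astype(int)   # [2, 0, 0]
    new_field = np.array([10.0, 20.0, 30.0])[template_field]  # [30., 10., 10.]
    new_field[nan_mask] = np.nan                  # [30., nan, 10.]

The indexing and the final nan restoration presumably happen in the lines omitted between these hunks.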
@@ -1014,13 +1037,10 @@ class Mesh3D(tfields.TensorMaps):
                     mp_fields.append(field[0:0])  # np.empty
                 else:
                     mp_fields.append(field[template_mp.fields[0].astype(int)])
-            new_mp = tfields.TensorFields(tfields.Tensors(template_mp),
-                                          *mp_fields)
+            new_mp = tfields.TensorFields(tfields.Tensors(template_mp), *mp_fields)
             maps.append(new_mp)
 
-        inst = tfields.Mesh3D(tfields.Tensors(template),
-                              *fields,
-                              maps=maps)
+        inst = tfields.Mesh3D(tfields.Tensors(template), *fields, maps=maps)
         return inst
 
     def cut(self, *args, **kwargs):
@@ -1149,9 +1169,7 @@ class Mesh3D(tfields.TensorMaps):
             templates = []
             for i, part in enumerate(parts):
                 template = part.copy()
-                template.maps[3].fields = [
-                    tfields.Tensors(mp_description[1][i])
-                ]
+                template.maps[3].fields = [tfields.Tensors(mp_description[1][i])]
                 templates.append(template)
             return parts, templates
 
@@ -1159,17 +1177,19 @@ class Mesh3D(tfields.TensorMaps):
         """
         Forwarding to rna.plotting.plot_mesh
         """
-        scalars_demanded = 'color' not in kwargs \
-            and 'facecolors' not in kwargs \
-            and any([v in kwargs for v in ['vmin', 'vmax', 'cmap']])
-        map_index = kwargs.pop('map_index',
-                               None if not scalars_demanded else 0)
+        scalars_demanded = (
+            "color" not in kwargs
+            and "facecolors" not in kwargs
+            and any([v in kwargs for v in ["vmin", "vmax", "cmap"]])
+        )
+        map_index = kwargs.pop("map_index", None if not scalars_demanded else 0)
         if map_index is not None:
             if not len(self.maps[3]) == 0:
-                kwargs['color'] = self.maps[3].fields[map_index]
+                kwargs["color"] = self.maps[3].fields[map_index]
         return rna.plotting.plot_mesh(self, self.faces, **kwargs)
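In effect: if the caller passes any of vmin, vmax or cmap but neither color nor facecolors, map_index defaults to 0 and the first face field is used as the colour source. A hedged example, assuming this is the mesh's public plot entry point and maps[3] carries at least one field:

    mesh.plot(cmap="viridis", vmin=0.0, vmax=1.0)  # colours faces by maps[3].fields[0]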
 
 
-if __name__ == '__main__':  # pragma: no cover
+if __name__ == "__main__":  # pragma: no cover
     import doctest
+
     doctest.testmod()