From 17aa5a5d0a3f3073a385fc9543a939d0441ad1ab Mon Sep 17 00:00:00 2001
From: "Boeckenhoff, Daniel (dboe)" <daniel.boeckenhoff@ipp.mpg.de>
Date: Fri, 26 Apr 2019 16:12:46 +0200
Subject: [PATCH] added project functionality to mesh3d

---
 tfields/mesh3D.py | 92 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 90 insertions(+), 2 deletions(-)

diff --git a/tfields/mesh3D.py b/tfields/mesh3D.py
index 62c4b0f..45d7485 100644
--- a/tfields/mesh3D.py
+++ b/tfields/mesh3D.py
@@ -564,11 +564,12 @@ class Mesh3D(tfields.TensorMaps):
             Mesh3D: template (see cut), can be used as template to retrieve
                 sub_mesh from self instance
         Examples:
+            >>> import tfields
+            >>> from sympy.abc import y
             >>> mp = tfields.TensorFields([[0,1,2],[2,3,0],[3,2,5],[5,4,3]],
             ...                           [1, 2, 3, 4])
             >>> m = tfields.Mesh3D([[0,0,0], [1,0,0], [1,1,0], [0,1,0], [0,2,0], [1,2,0]],
             ...                     maps=[mp])
-            >>> from sympy.abc import y
             >>> m_cut = m.cut(y < 1.5, at_intersection='split')
             >>> template = m.template(m_cut)
             >>> assert m_cut.equal(m.cut(template))
@@ -593,6 +594,93 @@ class Mesh3D(tfields.TensorMaps):
                         ]
         return inst
 
+    def project(self, tensor_field,
+                delta=None, merge_functions=None):
+        """
+        project the points of the tensor_field to a copy of the mesh
+        and set the face values accord to the field to the maps field.
+        If no field is present in tensor_field, the number of points in a mesh
+        is counted.
+
+        Args:
+            tensor_field (Tensors | TensorFields)
+            delta (float | None): forwarded to Mesh3D.in_faces
+            merge_functions (callable): if multiple Tensors lie in the same face,
+                they are mapped with the merge_function to one value
+
+        Examples:
+            >>> import tfields
+            >>> mp = tfields.TensorFields([[0,1,2],[2,3,0],[3,2,5],[5,4,3]],
+            ...                           [1, 2, 3, 4])
+            >>> m = tfields.Mesh3D([[0,0,0], [1,0,0], [1,1,0], [0,1,0], [0,2,0], [1,2,0]],
+            ...                     maps=[mp])
+
+            Projecting points onto the mesh gives the count
+            >>> points = tfields.Tensors([[0.5, 0.2, 0.0], [0.5, 0.02, 0.0], [0.5, 0.8, 0.0]])
+            >>> m_points = m.project(points)
+            >>> assert m_points.maps[0].fields[0].equal([2, 1, 0, 0])
+
+            TensorFields with arbitrary size are projected,
+            combinging the fields automatically
+            >>> fields = [tfields.Tensors([1,3,42]),
+            ...           tfields.Tensors([[0,1,2], [2,3,4], [3,4,5]]),
+            ...           tfields.Tensors([[[0, 0]] * 2,
+            ...                            [[2, 2]] * 2,
+            ...                            [[3, 3]] * 2])]
+            >>> tf = tfields.TensorFields(points, *fields)
+            >>> m_tf = m.project(tf)
+            >>> assert m_tf.maps[0].fields[0].equal([2, 42, np.nan, np.nan],
+            ...                                     equal_nan=True)
+            >>> assert m_tf.maps[0].fields[1].equal([[1, 2, 3],
+            ...                                      [3, 4, 5],
+            ...                                      [np.nan] * 3,
+            ...                                      [np.nan] * 3],
+            ...                                     equal_nan=True)
+            >>> assert m_tf.maps[0].fields[2].equal([[[1, 1]] * 2,
+            ...                                      [[3, 3]] * 2,
+            ...                                      [[np.nan, np.nan]] * 2,
+            ...                                      [[np.nan, np.nan]] * 2],
+            ...                                     equal_nan=True)
+
+        """
+        if not issubclass(type(tensor_field), tfields.Tensors):
+            tensor_field = tfields.TensorFields(tensor_field)
+        mask = self.in_faces(tensor_field, delta=None)
+        inst = self.copy()
+        
+        n_faces = len(self.maps[0])
+        if not hasattr(tensor_field, 'fields') or len(tensor_field.fields) == 0:
+            fields = [np.full(len(tensor_field), 1)]
+            empty_map_fields = [tfields.Tensors(np.full(n_faces, 0))]
+            if merge_functions is None:
+                merge_functions = [np.sum]
+        else:
+            fields = tensor_field.fields
+            empty_map_fields = []
+            for field in fields:
+                cls = type(field)
+                kwargs = {key: getattr(field, key) for key in cls.__slots__}
+                shape = (n_faces,) + field.shape[1:]
+                empty_map_fields.append(cls(np.full(shape, np.nan),
+                                            **kwargs))
+            if merge_functions is None:
+                merge_functions = [lambda x: np.mean(x, axis=0)] * len(fields)
+
+        map_fields = []
+        for field, map_field, merge_function in \
+                zip(fields, empty_map_fields, merge_functions):
+            for f in range(len(self.maps[0])):
+                res = field[mask[:, f]]
+                if len(res) == 0:
+                    continue
+                elif len(res) == 1:
+                    map_field[f] = res
+                else:
+                    map_field[f] = merge_function(res)
+            map_fields.append(map_field)
+        inst.maps[0].fields = map_fields
+        return inst
+
     def _cut_sympy(self, expression, at_intersection="remove", _in_recursion=False):
         """
         Partition the mesh with the cuts given and return the template
@@ -1010,5 +1098,5 @@ class Mesh3D(tfields.TensorMaps):
 if __name__ == '__main__':  # pragma: no cover
     import doctest
 
-    doctest.run_docstring_examples(Mesh3D.cut, globals())
+    doctest.run_docstring_examples(Mesh3D.project, globals())
     # doctest.testmod()
-- 
GitLab