From 6ebe8a72ea697f6676b72b20d1d6e4d4065eaf8d Mon Sep 17 00:00:00 2001
From: Philipp Arras <c@philipp-arras.de>
Date: Tue, 11 Mar 2025 20:21:13 +0100
Subject: [PATCH] check_operator: be cautious about dtypes

---
 src/extra.py       | 4 ++++
 src/field.py       | 3 +++
 src/multi_field.py | 4 ++++
 3 files changed, 11 insertions(+)

diff --git a/src/extra.py b/src/extra.py
index b908676e4..d32aeb52e 100644
--- a/src/extra.py
+++ b/src/extra.py
@@ -318,10 +318,14 @@ def _get_acceptable_location(op, loc, lin):
         direction = direction * (lin.val.norm() * fac)
     else:
         direction = direction * (lin.val.norm() * fac / dirder.norm())
+    direction = direction.astype(loc.dtype)
+    assert direction.dtype == loc.dtype
+
     # Find a step length that leads to a "reasonable" location
     for i in range(50):
         try:
             loc2 = loc + direction
+            assert loc2.dtype == loc.dtype
             lin2 = op(Linearization.make_var(loc2, lin.want_metric))
             if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20:
                 break
diff --git a/src/field.py b/src/field.py
index d854f493d..0be66a1ed 100644
--- a/src/field.py
+++ b/src/field.py
@@ -175,6 +175,9 @@ class Field(Operator):
         """type : the data type of the field's entries"""
         return self._val.dtype
 
+    def astype(self, dtype):
+        return Field(self._domain, np.astype(self._val, dtype))
+
     @property
     def domain(self):
         """DomainTuple : the field's domain"""
diff --git a/src/multi_field.py b/src/multi_field.py
index e87573901..a4c6df681 100644
--- a/src/multi_field.py
+++ b/src/multi_field.py
@@ -86,6 +86,10 @@ class MultiField(Operator):
     def dtype(self):
         return {key: val.dtype for key, val in self.items()}
 
+    def astype(self, dtype):
+        val = (vv.astype(dtype[kk]) for vv, kk in zip(self._val, self._domain.keys()))
+        return MultiField(self._domain, tuple(val))
+
     def _transform(self, op):
         return MultiField(self._domain, tuple(op(v) for v in self._val))
 
-- 
GitLab