From 6ebe8a72ea697f6676b72b20d1d6e4d4065eaf8d Mon Sep 17 00:00:00 2001 From: Philipp Arras <c@philipp-arras.de> Date: Tue, 11 Mar 2025 20:21:13 +0100 Subject: [PATCH] check_operator: be cautious about dtypes --- src/extra.py | 4 ++++ src/field.py | 3 +++ src/multi_field.py | 4 ++++ 3 files changed, 11 insertions(+) diff --git a/src/extra.py b/src/extra.py index b908676e4..d32aeb52e 100644 --- a/src/extra.py +++ b/src/extra.py @@ -318,10 +318,14 @@ def _get_acceptable_location(op, loc, lin): direction = direction * (lin.val.norm() * fac) else: direction = direction * (lin.val.norm() * fac / dirder.norm()) + direction = direction.astype(loc.dtype) + assert direction.dtype == loc.dtype + # Find a step length that leads to a "reasonable" location for i in range(50): try: loc2 = loc + direction + assert loc2.dtype == loc.dtype lin2 = op(Linearization.make_var(loc2, lin.want_metric)) if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20: break diff --git a/src/field.py b/src/field.py index d854f493d..0be66a1ed 100644 --- a/src/field.py +++ b/src/field.py @@ -175,6 +175,9 @@ class Field(Operator): """type : the data type of the field's entries""" return self._val.dtype + def astype(self, dtype): + return Field(self._domain, np.astype(self._val, dtype)) + @property def domain(self): """DomainTuple : the field's domain""" diff --git a/src/multi_field.py b/src/multi_field.py index e87573901..a4c6df681 100644 --- a/src/multi_field.py +++ b/src/multi_field.py @@ -86,6 +86,10 @@ class MultiField(Operator): def dtype(self): return {key: val.dtype for key, val in self.items()} + def astype(self, dtype): + val = (vv.astype(dtype[kk]) for vv, kk in zip(self._val, self._domain.keys())) + return MultiField(self._domain, tuple(val)) + def _transform(self, op): return MultiField(self._domain, tuple(op(v) for v in self._val)) -- GitLab