Commit 71262d82 authored by Philipp Arras's avatar Philipp Arras

Add dtype checks

parent 9a2cc287
Pipeline #62089 passed with stages
in 9 minutes and 6 seconds
......@@ -18,7 +18,9 @@ import scipy.sparse.linalg as ssl
from .domain_tuple import DomainTuple
from .domains.unstructured_domain import UnstructuredDomain
from .field import Field
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.linear_operator import LinearOperator
from .operators.sandwich_operator import SandwichOperator
from .sugar import from_global_data, makeDomain
......@@ -52,12 +54,14 @@ class _DomRemover(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
self._check_float_dtype(x)
x = x.to_global_data()
if isinstance(self._domain, DomainTuple):
res = x.ravel() if mode == self.TIMES else x.reshape(
self._domain.shape)
else:
res = np.empty(self.target.shape, dtype=x.dtype) if mode == self.TIMES else {}
res = np.empty(self.target.shape) if mode == self.TIMES else {}
for ii, (kk, dd) in enumerate(self.domain.items()):
i0, i1 = self._size_array[ii:ii + 2]
if mode == self.TIMES:
......@@ -66,6 +70,18 @@ class _DomRemover(LinearOperator):
res[kk] = x[i0:i1].reshape(dd.shape)
return from_global_data(self._tgt(mode), res)
@staticmethod
def _check_float_dtype(fld):
if isinstance(fld, MultiField):
dts = [ff.local_data.dtype for ff in fld.values()]
elif isinstance(fld, Field):
dts = [fld.local_data.dtype]
else:
raise TypeError
for dt in dts:
if not np.issubdtype(dt, np.float64):
raise TypeError('Operator supports only floating point dtypes')
def operator_spectrum(A, k, hermitian, which='LM', tol=0):
'''
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment