Commit 3fb9c0fb authored by Martin Reinecke's avatar Martin Reinecke
Browse files

faster type tests

parent 62d56279
......@@ -219,14 +219,14 @@ class Field(object):
@property
def real(self):
"""Field : The real part of the field"""
if not np.issubdtype(self.dtype, np.complexfloating):
return self
return Field(self._domain, self._val.real)
if utilities.iscomplextype(self.dtype):
return Field(self._domain, self._val.real)
return self
@property
def imag(self):
"""Field : The imaginary part of the field"""
if not np.issubdtype(self.dtype, np.complexfloating):
if not utilities.iscomplextype(self.dtype):
raise ValueError(".imag called on a non-complex Field")
return Field(self._domain, self._val.imag)
......@@ -384,7 +384,7 @@ class Field(object):
Field
The complex conjugated field.
"""
if np.issubdtype(self._val.dtype, np.complexfloating):
if utilities.iscomplextype(self._val.dtype):
return Field(self._domain, self._val.conjugate())
return self
......@@ -571,7 +571,7 @@ class Field(object):
return self._contraction_helper('var', spaces)
# MR FIXME: not very efficient or accurate
m1 = self.mean(spaces)
if np.issubdtype(self.dtype, np.complexfloating):
if utilities.iscomplextype(self.dtype):
sq = abs(self-m1)**2
else:
sq = (self-m1)**2
......
......@@ -98,7 +98,7 @@ class DiagonalOperator(EndomorphicOperator):
def _fill_rest(self):
self._ldiag.flags.writeable = False
self._complex = np.issubdtype(self._ldiag.dtype, np.complexfloating)
self._complex = utilities.iscomplextype(self._ldiag.dtype)
if not self._complex:
lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]
......
......@@ -80,7 +80,7 @@ class HartleyOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if np.issubdtype(x.dtype, np.complexfloating):
if utilities.iscomplextype(x.dtype):
return (self._apply_cartesian(x.real, mode) +
1j*self._apply_cartesian(x.imag, mode))
else:
......
......@@ -84,7 +84,7 @@ class ScalingOperator(EndomorphicOperator):
if trafo == 0:
return self
if trafo == ADJ and np.issubdtype(type(self._factor), np.floating):
if trafo == ADJ and not np.iscomplex(self._factor):
return self
if trafo == ADJ:
return ScalingOperator(np.conj(self._factor), self._domain)
......
......@@ -87,7 +87,7 @@ class SHTOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if np.issubdtype(x.dtype, np.complexfloating):
if utilities.iscomplextype(x.dtype):
return (self._apply_spherical(x.real, mode) +
1j*self._apply_spherical(x.imag, mode))
else:
......
......@@ -136,7 +136,7 @@ def power_analyze(field, spaces=None, binbounds=None,
if len(spaces) == 0:
raise ValueError("No space for analysis specified.")
field_real = not np.issubdtype(field.dtype, np.complexfloating)
field_real = not utilities.iscomplextype(field.dtype)
if (not field_real) and keep_phase_information:
raise ValueError("cannot keep phase from real-valued input Field")
......
......@@ -30,7 +30,7 @@ from .compat import *
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product", "frozendict", "special_add_at"]
"my_product", "frozendict", "special_add_at", "iscomplextype"]
def my_sum(iterable):
......@@ -212,7 +212,7 @@ def hartley(a, axes=None):
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if np.issubdtype(a.dtype, np.complexfloating):
if iscomplextype(a.dtype):
raise TypeError("Hartley transform requires real-valued arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
......@@ -256,7 +256,7 @@ def my_fftn_r2c(a, axes=None):
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape")
if np.issubdtype(a.dtype, np.complexfloating):
if iscomplextype(a.dtype):
raise TypeError("Transform requires real-valued input arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
......@@ -345,7 +345,7 @@ def special_add_at(a, axis, index, b):
sz3 = int(np.prod(a.shape[axis+1:]))
a2 = a.reshape([sz1, -1, sz3])
b2 = b.reshape([sz1, -1, sz3])
if np.issubdtype(a.dtype, np.complexfloating):
if iscomplextype(a.dtype):
dt2 = a.real.dtype
a2 = a2.view(dt2)
b2 = b2.view(dt2)
......@@ -355,6 +355,11 @@ def special_add_at(a, axis, index, b):
a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
minlength=a2.shape[1])
if np.issubdtype(a.dtype, np.complexfloating):
if iscomplextype(a.dtype):
a2 = a2.view(a.dtype)
return a2.reshape(a.shape)
_iscomplex_tpl = (np.complex64, np.complex128)
def iscomplextype(dtype):
return dtype.type in _iscomplex_tpl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment