Commit 3fb9c0fb authored by Martin Reinecke's avatar Martin Reinecke
Browse files

faster type tests

parent 62d56279
...@@ -219,14 +219,14 @@ class Field(object): ...@@ -219,14 +219,14 @@ class Field(object):
@property @property
def real(self): def real(self):
"""Field : The real part of the field""" """Field : The real part of the field"""
if not np.issubdtype(self.dtype, np.complexfloating): if utilities.iscomplextype(self.dtype):
return self return Field(self._domain, self._val.real)
return Field(self._domain, self._val.real) return self
@property @property
def imag(self): def imag(self):
"""Field : The imaginary part of the field""" """Field : The imaginary part of the field"""
if not np.issubdtype(self.dtype, np.complexfloating): if not utilities.iscomplextype(self.dtype):
raise ValueError(".imag called on a non-complex Field") raise ValueError(".imag called on a non-complex Field")
return Field(self._domain, self._val.imag) return Field(self._domain, self._val.imag)
...@@ -384,7 +384,7 @@ class Field(object): ...@@ -384,7 +384,7 @@ class Field(object):
Field Field
The complex conjugated field. The complex conjugated field.
""" """
if np.issubdtype(self._val.dtype, np.complexfloating): if utilities.iscomplextype(self._val.dtype):
return Field(self._domain, self._val.conjugate()) return Field(self._domain, self._val.conjugate())
return self return self
...@@ -571,7 +571,7 @@ class Field(object): ...@@ -571,7 +571,7 @@ class Field(object):
return self._contraction_helper('var', spaces) return self._contraction_helper('var', spaces)
# MR FIXME: not very efficient or accurate # MR FIXME: not very efficient or accurate
m1 = self.mean(spaces) m1 = self.mean(spaces)
if np.issubdtype(self.dtype, np.complexfloating): if utilities.iscomplextype(self.dtype):
sq = abs(self-m1)**2 sq = abs(self-m1)**2
else: else:
sq = (self-m1)**2 sq = (self-m1)**2
......
...@@ -98,7 +98,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -98,7 +98,7 @@ class DiagonalOperator(EndomorphicOperator):
def _fill_rest(self): def _fill_rest(self):
self._ldiag.flags.writeable = False self._ldiag.flags.writeable = False
self._complex = np.issubdtype(self._ldiag.dtype, np.complexfloating) self._complex = utilities.iscomplextype(self._ldiag.dtype)
if not self._complex: if not self._complex:
lmin = self._ldiag.min() if self._ldiag.size > 0 else 1. lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()] self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]
......
...@@ -80,7 +80,7 @@ class HartleyOperator(LinearOperator): ...@@ -80,7 +80,7 @@ class HartleyOperator(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if np.issubdtype(x.dtype, np.complexfloating): if utilities.iscomplextype(x.dtype):
return (self._apply_cartesian(x.real, mode) + return (self._apply_cartesian(x.real, mode) +
1j*self._apply_cartesian(x.imag, mode)) 1j*self._apply_cartesian(x.imag, mode))
else: else:
......
...@@ -84,7 +84,7 @@ class ScalingOperator(EndomorphicOperator): ...@@ -84,7 +84,7 @@ class ScalingOperator(EndomorphicOperator):
if trafo == 0: if trafo == 0:
return self return self
if trafo == ADJ and np.issubdtype(type(self._factor), np.floating): if trafo == ADJ and not np.iscomplex(self._factor):
return self return self
if trafo == ADJ: if trafo == ADJ:
return ScalingOperator(np.conj(self._factor), self._domain) return ScalingOperator(np.conj(self._factor), self._domain)
......
...@@ -87,7 +87,7 @@ class SHTOperator(LinearOperator): ...@@ -87,7 +87,7 @@ class SHTOperator(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if np.issubdtype(x.dtype, np.complexfloating): if utilities.iscomplextype(x.dtype):
return (self._apply_spherical(x.real, mode) + return (self._apply_spherical(x.real, mode) +
1j*self._apply_spherical(x.imag, mode)) 1j*self._apply_spherical(x.imag, mode))
else: else:
......
...@@ -136,7 +136,7 @@ def power_analyze(field, spaces=None, binbounds=None, ...@@ -136,7 +136,7 @@ def power_analyze(field, spaces=None, binbounds=None,
if len(spaces) == 0: if len(spaces) == 0:
raise ValueError("No space for analysis specified.") raise ValueError("No space for analysis specified.")
field_real = not np.issubdtype(field.dtype, np.complexfloating) field_real = not utilities.iscomplextype(field.dtype)
if (not field_real) and keep_phase_information: if (not field_real) and keep_phase_information:
raise ValueError("cannot keep phase from real-valued input Field") raise ValueError("cannot keep phase from real-valued input Field")
......
...@@ -30,7 +30,7 @@ from .compat import * ...@@ -30,7 +30,7 @@ from .compat import *
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space", __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c", "memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb", "my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product", "frozendict", "special_add_at"] "my_product", "frozendict", "special_add_at", "iscomplextype"]
def my_sum(iterable): def my_sum(iterable):
...@@ -212,7 +212,7 @@ def hartley(a, axes=None): ...@@ -212,7 +212,7 @@ def hartley(a, axes=None):
if axes is not None and \ if axes is not None and \
not all(axis < len(a.shape) for axis in axes): not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape") raise ValueError("Provided axes do not match array shape")
if np.issubdtype(a.dtype, np.complexfloating): if iscomplextype(a.dtype):
raise TypeError("Hartley transform requires real-valued arrays.") raise TypeError("Hartley transform requires real-valued arrays.")
from pyfftw.interfaces.numpy_fft import rfftn from pyfftw.interfaces.numpy_fft import rfftn
...@@ -256,7 +256,7 @@ def my_fftn_r2c(a, axes=None): ...@@ -256,7 +256,7 @@ def my_fftn_r2c(a, axes=None):
if axes is not None and \ if axes is not None and \
not all(axis < len(a.shape) for axis in axes): not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes do not match array shape") raise ValueError("Provided axes do not match array shape")
if np.issubdtype(a.dtype, np.complexfloating): if iscomplextype(a.dtype):
raise TypeError("Transform requires real-valued input arrays.") raise TypeError("Transform requires real-valued input arrays.")
from pyfftw.interfaces.numpy_fft import rfftn from pyfftw.interfaces.numpy_fft import rfftn
...@@ -345,7 +345,7 @@ def special_add_at(a, axis, index, b): ...@@ -345,7 +345,7 @@ def special_add_at(a, axis, index, b):
sz3 = int(np.prod(a.shape[axis+1:])) sz3 = int(np.prod(a.shape[axis+1:]))
a2 = a.reshape([sz1, -1, sz3]) a2 = a.reshape([sz1, -1, sz3])
b2 = b.reshape([sz1, -1, sz3]) b2 = b.reshape([sz1, -1, sz3])
if np.issubdtype(a.dtype, np.complexfloating): if iscomplextype(a.dtype):
dt2 = a.real.dtype dt2 = a.real.dtype
a2 = a2.view(dt2) a2 = a2.view(dt2)
b2 = b2.view(dt2) b2 = b2.view(dt2)
...@@ -355,6 +355,11 @@ def special_add_at(a, axis, index, b): ...@@ -355,6 +355,11 @@ def special_add_at(a, axis, index, b):
a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3], a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
minlength=a2.shape[1]) minlength=a2.shape[1])
if np.issubdtype(a.dtype, np.complexfloating): if iscomplextype(a.dtype):
a2 = a2.view(a.dtype) a2 = a2.view(a.dtype)
return a2.reshape(a.shape) return a2.reshape(a.shape)
_iscomplex_tpl = (np.complex64, np.complex128)
def iscomplextype(dtype):
return dtype.type in _iscomplex_tpl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment