Commit 24143e6a authored by Martin Reinecke's avatar Martin Reinecke

remove dtype from spaces

parent 7967f2d5
from nifty import *
#import plotly.offline as pl
#import plotly.graph_objs as go
import plotly.offline as pl
import plotly.graph_objs as go
from mpi4py import MPI
comm = MPI.COMM_WORLD
......@@ -10,10 +10,10 @@ rank = comm.rank
if __name__ == "__main__":
distribution_strategy = 'fftw'
distribution_strategy = 'not'
# Setting up the geometry
s_space = RGSpace([512, 512], dtype=np.float64)
s_space = RGSpace([512, 512])
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
......@@ -50,8 +50,7 @@ if __name__ == "__main__":
d_data = d.val.get_full_data().real
m_data = m.val.get_full_data().real
ss_data = ss.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
......@@ -53,10 +53,10 @@ class WienerFilterEnergy(Energy):
if __name__ == "__main__":
distribution_strategy = 'fftw'
distribution_strategy = 'not'
# Set up spaces and fft transformation
s_space = RGSpace([512, 512], dtype=np.float)
s_space = RGSpace([512, 512])
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
......
......@@ -27,8 +27,7 @@ from keepers import Loggable,\
class DomainObject(Versionable, Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, dtype):
self._dtype = np.dtype(dtype)
def __init__(self):
self._ignore_for_hash = []
def __hash__(self):
......@@ -50,10 +49,6 @@ class DomainObject(Versionable, Loggable, object):
def __ne__(self, x):
return not self.__eq__(x)
@property
def dtype(self):
return self._dtype
@abc.abstractproperty
def shape(self):
raise NotImplementedError(
......@@ -78,10 +73,9 @@ class DomainObject(Versionable, Loggable, object):
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype']))
result = cls()
return result
......@@ -94,8 +94,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple = (np.dtype(gc['default_field_dtype']),)
else:
dtype_tuple = (np.dtype(dtype),)
if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
......@@ -345,7 +343,7 @@ class Field(Loggable, Versionable, object):
# create random samples: one or two, depending on whether the
# power spectrum is real or complex
if issubclass(power_domain.dtype.type, np.complexfloating):
if issubclass(self.dtype.type, np.complexfloating):
result_list = [None, None]
else:
result_list = [None]
......@@ -355,7 +353,7 @@ class Field(Loggable, Versionable, object):
mean=mean,
std=std,
domain=result_domain,
dtype=harmonic_domain.dtype,
dtype=self.dtype,
distribution_strategy=self.distribution_strategy)
for x in result_list]
......@@ -403,7 +401,7 @@ class Field(Loggable, Versionable, object):
lambda x: x * local_rescaler.real,
inplace=True)
if issubclass(power_domain.dtype.type, np.complexfloating):
if issubclass(self.dtype.type, np.complexfloating):
result_val_list[1].apply_scalar_function(
lambda x: x * local_rescaler.imag,
inplace=True)
......@@ -412,7 +410,7 @@ class Field(Loggable, Versionable, object):
[x.set_val(new_val=y, copy=False) for x, y in
zip(result_list, result_val_list)]
if issubclass(power_domain.dtype.type, np.complexfloating):
if issubclass(self.dtype.type, np.complexfloating):
result = result_list[0] + 1j*result_list[1]
else:
result = result_list[0]
......
......@@ -69,7 +69,7 @@ class LMGLTransformation(SlicingTransformation):
nlat = domain.lmax + 1
nlon = domain.lmax*2 + 1
result = GLSpace(nlat=nlat, nlon=nlon, dtype=domain.dtype)
result = GLSpace(nlat=nlat, nlon=nlon)
return result
@classmethod
......
......@@ -65,7 +65,7 @@ class LMHPTransformation(SlicingTransformation):
raise TypeError("domain needs to be a LMSpace.")
nside = max((domain.lmax + 1)//2, 1)
result = HPSpace(nside=nside, dtype=domain.dtype)
result = HPSpace(nside=nside)
return result
@classmethod
......
......@@ -312,7 +312,7 @@ class FFTW(Transform):
try:
# Create return object and insert results inplace
result_dtype = np.result_type(np.complex, self.codomain.dtype)
result_dtype = np.complex
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
return_val.set_local_data(data=local_result, copy=False)
......@@ -341,7 +341,7 @@ class FFTW(Transform):
np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
)
local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)
result_dtype = np.result_type(np.complex, self.codomain.dtype)
result_dtype = np.complex
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
......@@ -583,7 +583,7 @@ class NUMPYFFT(Transform):
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
result_dtype = np.result_type(np.complex, self.codomain.dtype)
result_dtype = np.complex
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
......
......@@ -44,7 +44,7 @@ class RGRGTransformation(Transformation):
raise ValueError('ERROR: unknow FFT module:' + module)
@classmethod
def get_codomain(cls, domain, dtype=None, zerocenter=None):
def get_codomain(cls, domain, zerocenter=None):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ either a shifted grid or a Fourier conjugate
......@@ -82,8 +82,7 @@ class RGRGTransformation(Transformation):
new_space = RGSpace(domain.shape,
zerocenter=zerocenter,
distances=distances,
harmonic=(not domain.harmonic),
dtype=domain.dtype)
harmonic=(not domain.harmonic))
# better safe than sorry
cls.check_codomain(domain, new_space)
......
......@@ -45,9 +45,7 @@ class Transformation(Loggable, object):
@classmethod
def check_codomain(cls, domain, codomain):
if np.dtype(domain.dtype) != np.dtype(codomain.dtype):
cls.Logger.warn("Unrecommended: domain and codomain don't have "
"the same dtype.")
pass
def transform(self, val, axes=None, **kwargs):
raise NotImplementedError
......@@ -45,8 +45,6 @@ class GLSpace(Space):
Number of latitudinal bins, or rings.
nlon : int, *optional*
Number of longitudinal bins (default: ``2*nlat - 1``).
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
See Also
--------
......@@ -62,15 +60,11 @@ class GLSpace(Space):
High-Resolution Discretization and Fast Analysis of Data
Distributed on the Sphere", *ApJ* 622..759G.
Attributes
----------
dtype : numpy.dtype
Data type of the field values.
"""
# ---Overwritten properties and methods---
def __init__(self, nlat, nlon=None, dtype=None):
def __init__(self, nlat, nlon=None):
"""
Sets the attributes for a gl_space class instance.
......@@ -80,8 +74,6 @@ class GLSpace(Space):
Number of latitudinal bins, or rings.
nlon : int, *optional*
Number of longitudinal bins (default: ``2*nlat - 1``).
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
Returns
-------
......@@ -97,7 +89,7 @@ class GLSpace(Space):
raise ImportError(
"The module pyHealpix is needed but not available.")
super(GLSpace, self).__init__(dtype)
super(GLSpace, self).__init__()
self._nlat = self._parse_nlat(nlat)
self._nlon = self._parse_nlon(nlon)
......@@ -122,8 +114,7 @@ class GLSpace(Space):
def copy(self):
return self.__class__(nlat=self.nlat,
nlon=self.nlon,
dtype=self.dtype)
nlon=self.nlon)
def weight(self, x, power=1, axes=None, inplace=False):
nlon = self.nlon
......@@ -184,7 +175,6 @@ class GLSpace(Space):
def _to_hdf5(self, hdf5_group):
hdf5_group['nlat'] = self.nlat
hdf5_group['nlon'] = self.nlon
hdf5_group.attrs['dtype'] = self.dtype.name
return None
......@@ -193,10 +183,9 @@ class GLSpace(Space):
result = cls(
nlat=hdf5_group['nlat'][()],
nlon=hdf5_group['nlon'][()],
dtype=np.dtype(hdf5_group.attrs['dtype'])
)
return result
def plot(self):
pass
\ No newline at end of file
pass
......@@ -54,15 +54,11 @@ class HPSpace(Space):
harmonic transforms revisited";
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
Attributes
----------
dtype : numpy.dtype
Data type of the field values, which is always numpy.float64.
"""
# ---Overwritten properties and methods---
def __init__(self, nside, dtype=None):
def __init__(self, nside):
"""
Sets the attributes for a hp_space class instance.
......@@ -83,7 +79,7 @@ class HPSpace(Space):
"""
super(HPSpace, self).__init__(dtype)
super(HPSpace, self).__init__()
self._nside = self._parse_nside(nside)
......@@ -106,8 +102,7 @@ class HPSpace(Space):
return 4 * np.pi
def copy(self):
return self.__class__(nside=self.nside,
dtype=self.dtype)
return self.__class__(nside=self.nside)
def weight(self, x, power=1, axes=None, inplace=False):
weight = ((4 * np.pi) / (12 * self.nside**2))**power
......@@ -142,16 +137,14 @@ class HPSpace(Space):
def _to_hdf5(self, hdf5_group):
hdf5_group['nside'] = self.nside
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(
nside=hdf5_group['nside'][()],
dtype=np.dtype(hdf5_group.attrs['dtype'])
)
return result
def plot(self):
pass
\ No newline at end of file
pass
......@@ -42,8 +42,6 @@ class LMSpace(Space):
lmax : int
Maximum :math:`\ell`-value up to which the spherical harmonics
coefficients are to be used.
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.complex128).
See Also
--------
......@@ -66,14 +64,9 @@ class LMSpace(Space):
.. [#] M. Reinecke and D. Sverre Seljebotn, 2013, "Libsharp - spherical
harmonic transforms revisited";
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
Attributes
----------
dtype : numpy.dtype
Data type of the field values.
"""
def __init__(self, lmax, dtype=None):
def __init__(self, lmax):
"""
Sets the attributes for an lm_space class instance.
......@@ -82,8 +75,6 @@ class LMSpace(Space):
lmax : int
Maximum :math:`\ell`-value up to which the spherical harmonics
coefficients are to be used.
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.complex128).
Returns
-------
......@@ -91,7 +82,7 @@ class LMSpace(Space):
"""
super(LMSpace, self).__init__(dtype)
super(LMSpace, self).__init__()
self._lmax = self._parse_lmax(lmax)
def hermitian_decomposition(self, x, axes=None,
......@@ -125,11 +116,10 @@ class LMSpace(Space):
@property
def total_volume(self):
# the individual pixels have a fixed volume of 1.
return np.float(self.dim)
return np.float64(self.dim)
def copy(self):
return self.__class__(lmax=self.lmax,
dtype=self.dtype)
return self.__class__(lmax=self.lmax)
def weight(self, x, power=1, axes=None, inplace=False):
if inplace:
......@@ -143,7 +133,7 @@ class LMSpace(Space):
dists = dists.apply_scalar_function(
lambda x: self._distance_array_helper(x, self.lmax),
dtype=np.float)
dtype=np.float64)
return dists
......@@ -178,17 +168,15 @@ class LMSpace(Space):
def _to_hdf5(self, hdf5_group):
hdf5_group['lmax'] = self.lmax
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(
lmax=hdf5_group['lmax'][()],
dtype=np.dtype(hdf5_group.attrs['dtype'])
)
return result
def plot(self):
pass
\ No newline at end of file
pass
......@@ -33,10 +33,9 @@ class PowerSpace(Space):
def __init__(self, harmonic_domain=RGSpace((1,)),
distribution_strategy='not',
log=False, nbin=None, binbounds=None,
dtype=None):
log=False, nbin=None, binbounds=None):
super(PowerSpace, self).__init__(dtype)
super(PowerSpace, self).__init__()
self._ignore_for_hash += ['_pindex', '_kindex', '_rho', '_pundex',
'_k_array']
......@@ -97,8 +96,7 @@ class PowerSpace(Space):
distribution_strategy=distribution_strategy,
log=self.log,
nbin=self.nbin,
binbounds=self.binbounds,
dtype=self.dtype)
binbounds=self.binbounds)
def weight(self, x, power=1, axes=None, inplace=False):
reshaper = [1, ] * len(x.shape)
......@@ -171,7 +169,6 @@ class PowerSpace(Space):
hdf5_group['kindex'] = self.kindex
hdf5_group['rho'] = self.rho
hdf5_group['pundex'] = self.pundex
hdf5_group.attrs['dtype'] = self.dtype.name
hdf5_group['log'] = self.log
# Store nbin as string, since it can be None
hdf5_group.attrs['nbin'] = str(self.nbin)
......@@ -190,7 +187,7 @@ class PowerSpace(Space):
# reset class
new_ps.__class__ = cls
# call instructor so that classes are properly setup
super(PowerSpace, new_ps).__init__(np.dtype(hdf5_group.attrs['dtype']))
super(PowerSpace, new_ps).__init__()
# set all values
new_ps._harmonic_domain = repository.get('harmonic_domain', hdf5_group)
new_ps._log = hdf5_group['log'][()]
......
......@@ -52,9 +52,6 @@ class RGSpace(Space):
Attributes
----------
dtype : numpy.dtype
Data type of the field values for a field defined on this space,
either ``numpy.float64`` or ``numpy.complex128``.
harmonic : bool
Whether or not the grid represents a Fourier basis.
zerocenter : {bool, numpy.ndarray}, *optional*
......@@ -67,7 +64,7 @@ class RGSpace(Space):
# ---Overwritten properties and methods---
def __init__(self, shape=(1,), zerocenter=False, distances=None,
harmonic=False, dtype=None):
harmonic=False):
"""
Sets the attributes for an rg_space class instance.
......@@ -92,13 +89,7 @@ class RGSpace(Space):
"""
self._harmonic = bool(harmonic)
if dtype is None:
if self.harmonic:
dtype = np.dtype('complex')
else:
dtype = np.dtype('float')
super(RGSpace, self).__init__(dtype)
super(RGSpace, self).__init__()
self._shape = self._parse_shape(shape)
self._distances = self._parse_distances(distances)
......@@ -198,8 +189,7 @@ class RGSpace(Space):
return self.__class__(shape=self.shape,
zerocenter=self.zerocenter,
distances=self.distances,
harmonic=self.harmonic,
dtype=self.dtype)
harmonic=self.harmonic)
def weight(self, x, power=1, axes=None, inplace=False):
weight = reduce(lambda x, y: x*y, self.distances)**power
......@@ -293,11 +283,11 @@ class RGSpace(Space):
def _parse_distances(self, distances):
if distances is None:
if self.harmonic:
temp = np.ones_like(self.shape, dtype=np.float)
temp = np.ones_like(self.shape, dtype=np.float64)
else:
temp = 1 / np.array(self.shape, dtype=np.float)
temp = 1 / np.array(self.shape, dtype=np.float64)
else:
temp = np.empty(len(self.shape), dtype=np.float)
temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances
return tuple(temp)
......@@ -313,7 +303,6 @@ class RGSpace(Space):
hdf5_group['zerocenter'] = self.zerocenter
hdf5_group['distances'] = self.distances
hdf5_group['harmonic'] = self.harmonic
hdf5_group.attrs['dtype'] = self.dtype.name
return None
......@@ -324,7 +313,6 @@ class RGSpace(Space):
zerocenter=hdf5_group['zerocenter'][:],
distances=hdf5_group['distances'][:],
harmonic=hdf5_group['harmonic'][()],
dtype=np.dtype(hdf5_group.attrs['dtype'])
)
return result
......
......@@ -147,24 +147,18 @@ from nifty.domain_object import DomainObject
class Space(DomainObject):
def __init__(self, dtype=np.dtype('float')):
def __init__(self):
"""
Parameters
----------
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
None.
Returns
-------
None.
"""
# parse dtype
casted_dtype = np.result_type(dtype, np.float64)
if casted_dtype != dtype:
self.Logger.warning("Input dtype reset to: %s" % str(casted_dtype))
super(Space, self).__init__(dtype=casted_dtype)
super(Space, self).__init__()
self._ignore_for_hash += ['_global_id']
@abc.abstractproperty
......@@ -178,7 +172,7 @@ class Space(DomainObject):
@abc.abstractmethod
def copy(self):
return self.__class__(dtype=self.dtype)
return self.__class__()
def get_distance_array(self, distribution_strategy):
raise NotImplementedError(
......@@ -195,5 +189,4 @@ class Space(DomainObject):
def __repr__(self):
string = ""
string += str(type(self)) + "\n"
string += "dtype: " + str(self.dtype) + "\n"
return string
......@@ -31,11 +31,10 @@ def create_power_operator(domain, power_spectrum, dtype=None,
domain = fft.target[0]
power_domain = PowerSpace(domain,
dtype=dtype,
distribution_strategy=distribution_strategy)
fp = Field(power_domain,
val=power_spectrum,
val=power_spectrum,dtype=dtype,
distribution_strategy=distribution_strategy)
f = fp.power_synthesize(mean=1, std=0, real_signal=False)
......
......@@ -35,14 +35,13 @@ from test.common import expand
np.random.seed(123)
SPACES = [RGSpace((4,), dtype=np.float), RGSpace((5), dtype=np.complex)]
SPACES = [RGSpace((4,)), RGSpace((5))]
SPACE_COMBINATIONS = [(), SPACES[0], SPACES[1], SPACES]
class Test_Interface(unittest.TestCase):
@expand(product(SPACE_COMBINATIONS,
[['dtype', np.dtype],
['distribution_strategy', str],
[['distribution_strategy', str],
['domain', tuple],
['domain_axes', tuple],
['val', distributed_data_object],
......
......@@ -43,6 +43,11 @@ def _harmonic_type(itp):
elif otp==np.float32:
otp=np.complex64
return otp
def _get_rtol(tp):
if (tp==np.float64) or (tp==np.complex128):
return 1e-10
else:
return 1e-5
class Misc_Tests(unittest.TestCase):
......@@ -70,12 +75,13 @@ class Misc_Tests(unittest.TestCase):
for zc2 in [False,True]:
for d in [0.1,1,3.7]:
for itp in [np.float64,np.complex128,np.float32,np.complex64]:
tol=_get_rtol(itp)
a = RGSpace(dim1, zerocenter=zc1, distances=d)
b = RGRGTransformation.get_codomain(a, zerocenter=zc2)
fft = FFTOperator(domain=a, target=b, domain_dtype=itp, target_dtype=_harmonic_type(itp))
inp = Field.from_random(domain=a,random_type='normal',std=7,mean=3,dtype=itp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.val, out.val)
assert_allclose(inp.val, out.val,rtol=tol,atol=tol)
def test_fft2D(self):
for dim1 in [10,11]:
......@@ -86,24 +92,26 @@ class Misc_Tests(unittest.TestCase):
for zc4 in [False,True]:
for d in [0.1,1,3.7]:
for itp in [np.float64,np.complex128,np.float32,np.complex64]:
tol=_get_rtol(itp)
a = RGSpace([dim1,dim2], zerocenter=[zc1,zc2], distances=d)
b = RGRGTransformation.get_codomain(a, zerocenter=[zc3,zc4])
fft = FFTOperator(domain=a, target=b, domain_dtype=itp, target_dtype=_harmonic_type(itp))
inp = Field.from_random(domain=a,random_type='normal',std=7,mean=3,dtype=itp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.val, out.val)
assert_allclose(inp.val, out.val,rtol=tol,atol=tol)
def test_sht(self):
if 'pyHealpix' not in di:
raise SkipTest
for lm in [0,3,6,11,30]:
for tp in [np.float64,np.complex128,np.float32,np.complex64]:
tol=_get_rtol(tp)
a = LMSpace(lmax=lm)
b = LMGLTransformation.get_codomain(a)
fft = FFTOperator(domain=a, target=b, domain_dtype=tp, target_dtype=tp)
inp = Field.from_random(domain=a,random_type='normal',std=7,mean=3,dtype=tp)
out = fft.inverse_times(fft.times(inp))
assert_allclose(inp.val, out.val)
assert_allclose(inp.val, out.val,rtol=tol,atol=tol)
def test_sht2(self):
if 'pyHealpix' not in di:
......
......@@ -27,18 +27,17 @@ from nifty import GLSpace
from nifty.config import dependency_injector as di
from test.common import expand
# [nlat, nlon, dtype, expected]
# [nlat, nlon, expected]
CONSTRUCTOR_CONFIGS = [
[2, None, None, {
[2, None, {
'nlat': 2,
'nlon': 3,
'harmonic': False,
'shape': (6,),
'dim': 6,
'total_volume': 4 * np.pi,
'dtype': np.dtype('float64')
'total_volume': 4 * np.pi
}],
[0, None, None, {
[0, None, {
'error': ValueError
}]
<