Commit 6c70837a authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'master' into multi_power_synth

# Conflicts:
#	nifty/field.py
parents 71512cab 9ba59e8c
......@@ -10,10 +10,10 @@ rank = comm.rank
if __name__ == "__main__":
distribution_strategy = 'fftw'
distribution_strategy = 'not'
# Setting up the geometry
s_space = RGSpace([512, 512], dtype=np.float64)
s_space = RGSpace([512, 512])
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
......@@ -54,4 +54,3 @@ if __name__ == "__main__":
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
......@@ -53,10 +53,10 @@ class WienerFilterEnergy(Energy):
if __name__ == "__main__":
distribution_strategy = 'fftw'
distribution_strategy = 'not'
# Set up spaces and fft transformation
s_space = RGSpace([512, 512], dtype=np.float)
s_space = RGSpace([512, 512])
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
......
......@@ -18,8 +18,6 @@
import abc
import numpy as np
from keepers import Loggable,\
Versionable
......@@ -27,8 +25,7 @@ from keepers import Loggable,\
class DomainObject(Versionable, Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, dtype):
self._dtype = np.dtype(dtype)
def __init__(self):
self._ignore_for_hash = []
def __hash__(self):
......@@ -50,10 +47,6 @@ class DomainObject(Versionable, Loggable, object):
def __ne__(self, x):
return not self.__eq__(x)
@property
def dtype(self):
return self._dtype
@abc.abstractproperty
def shape(self):
raise NotImplementedError(
......@@ -78,10 +71,9 @@ class DomainObject(Versionable, Loggable, object):
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype']))
result = cls()
return result
......@@ -45,8 +45,7 @@ class Field(Loggable, Versionable, object):
self.domain_axes = self._get_axes_tuple(self.domain)
self.dtype = self._infer_dtype(dtype=dtype,
val=val,
domain=self.domain)
val=val)
self.distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy,
......@@ -86,18 +85,17 @@ class Field(Loggable, Versionable, object):
axes_list += [tuple(l)]
return tuple(axes_list)
def _infer_dtype(self, dtype, val, domain):
def _infer_dtype(self, dtype, val):
if dtype is None:
if isinstance(val, Field) or \
isinstance(val, distributed_data_object):
try:
dtype = val.dtype
dtype_tuple = (np.dtype(gc['default_field_dtype']),)
except AttributeError:
if val is not None:
dtype = np.result_type(val)
else:
dtype_tuple = (np.dtype(dtype),)
if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
dtype = np.dtype(gc['default_field_dtype'])
else:
dtype = np.dtype(dtype)
return dtype
......
......@@ -83,12 +83,14 @@ class FFTOperator(LinearOperator):
# Store the dtype information
if domain_dtype is None:
self.domain_dtype = None
self.logger.info("Setting domain_dtype to np.float.")
self.domain_dtype = np.float
else:
self.domain_dtype = np.dtype(domain_dtype)
if target_dtype is None:
self.target_dtype = None
self.logger.info("Setting target_dtype to np.complex.")
self.target_dtype = np.complex
else:
self.target_dtype = np.dtype(target_dtype)
......
......@@ -21,7 +21,7 @@ import numpy as np
from nifty.config import dependency_injector as gdi
from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory
import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -31,12 +31,17 @@ class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if 'pyHealpix' not in gdi:
raise ImportError(
"The module pyHealpix is needed but not available.")
super(GLLMTransformation, self).__init__(domain, codomain,
module=module)
super(GLLMTransformation, self).__init__(domain, codomain, module)
# ---Mandatory properties and methods---
......@@ -63,7 +68,7 @@ class GLLMTransformation(SlicingTransformation):
nlat = domain.nlat
lmax = nlat - 1
result = LMSpace(lmax=lmax, dtype=domain.dtype)
result = LMSpace(lmax=lmax)
return result
@classmethod
......@@ -91,6 +96,11 @@ class GLLMTransformation(SlicingTransformation):
super(GLLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
nlat = self.domain.nlat
nlon = self.domain.nlon
lmax = self.codomain.lmax
......@@ -104,12 +114,12 @@ class GLLMTransformation(SlicingTransformation):
for x in (inp.real, inp.imag)]
[resultReal,
resultImag] = [lm_transformation_factory.buildIdx(x, lmax=lmax)
resultImag] = [lm_transformation_helper.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
else:
result = sjob.map2alm(inp)
result = lm_transformation_factory.buildIdx(result, lmax=lmax)
result = lm_transformation_helper.buildIdx(result, lmax=lmax)
return result
......@@ -22,7 +22,7 @@ from nifty.config import dependency_injector as gdi
from nifty import HPSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory
import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -32,12 +32,17 @@ class HPLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if 'pyHealpix' not in gdi:
raise ImportError(
"The module pyHealpix is needed but not available")
super(HPLMTransformation, self).__init__(domain, codomain,
module=module)
super(HPLMTransformation, self).__init__(domain, codomain, module)
# ---Mandatory properties and methods---
......@@ -63,7 +68,7 @@ class HPLMTransformation(SlicingTransformation):
lmax = 2*domain.nside
result = LMSpace(lmax=lmax, dtype=domain.dtype)
result = LMSpace(lmax=lmax)
return result
@classmethod
......@@ -83,6 +88,11 @@ class HPLMTransformation(SlicingTransformation):
super(HPLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
lmax = self.codomain.lmax
mmax = lmax
......@@ -92,13 +102,13 @@ class HPLMTransformation(SlicingTransformation):
for x in (inp.real, inp.imag)]
[resultReal,
resultImag] = [lm_transformation_factory.buildIdx(x, lmax=lmax)
resultImag] = [lm_transformation_helper.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag)
else:
result = pyHealpix.map2alm_iter(inp, lmax, mmax, 3)
result = lm_transformation_factory.buildIdx(result, lmax=lmax)
result = lm_transformation_helper.buildIdx(result, lmax=lmax)
return result
......@@ -21,7 +21,7 @@ from nifty.config import dependency_injector as gdi
from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory
import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -31,12 +31,17 @@ class LMGLTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if 'pyHealpix' not in gdi:
raise ImportError(
"The module pyHealpix is needed but not available.")
super(LMGLTransformation, self).__init__(domain, codomain,
module=module)
super(LMGLTransformation, self).__init__(domain, codomain, module)
# ---Mandatory properties and methods---
......@@ -69,7 +74,7 @@ class LMGLTransformation(SlicingTransformation):
nlat = domain.lmax + 1
nlon = domain.lmax*2 + 1
result = GLSpace(nlat=nlat, nlon=nlon, dtype=domain.dtype)
result = GLSpace(nlat=nlat, nlon=nlon)
return result
@classmethod
......@@ -97,6 +102,11 @@ class LMGLTransformation(SlicingTransformation):
super(LMGLTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
nlat = self.codomain.nlat
nlon = self.codomain.nlon
lmax = self.domain.lmax
......@@ -107,7 +117,7 @@ class LMGLTransformation(SlicingTransformation):
sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal,
resultImag] = [lm_transformation_factory.buildLm(x, lmax=lmax)
resultImag] = [lm_transformation_helper.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [sjob.alm2map(x)
......@@ -116,7 +126,7 @@ class LMGLTransformation(SlicingTransformation):
result = self._combine_complex_result(resultReal, resultImag)
else:
result = lm_transformation_factory.buildLm(inp, lmax=lmax)
result = lm_transformation_helper.buildLm(inp, lmax=lmax)
result = sjob.alm2map(result)
return result
......@@ -20,7 +20,7 @@ import numpy as np
from nifty.config import dependency_injector as gdi
from nifty import HPSpace, LMSpace
from slicing_transformation import SlicingTransformation
import lm_transformation_factory
import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
......@@ -30,12 +30,17 @@ class LMHPTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if gdi.get('pyHealpix') is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
super(LMHPTransformation, self).__init__(domain, codomain,
module=module)
super(LMHPTransformation, self).__init__(domain, codomain, module)
# ---Mandatory properties and methods---
......@@ -65,7 +70,7 @@ class LMHPTransformation(SlicingTransformation):
raise TypeError("domain needs to be a LMSpace.")
nside = max((domain.lmax + 1)//2, 1)
result = HPSpace(nside=nside, dtype=domain.dtype)
result = HPSpace(nside=nside)
return result
@classmethod
......@@ -85,13 +90,18 @@ class LMHPTransformation(SlicingTransformation):
super(LMHPTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
nside = self.codomain.nside
lmax = self.domain.lmax
mmax = lmax
if issubclass(inp.dtype.type, np.complexfloating):
[resultReal,
resultImag] = [lm_transformation_factory.buildLm(x, lmax=lmax)
resultImag] = [lm_transformation_helper.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [pyHealpix.alm2map(x, lmax, mmax, nside)
......@@ -100,7 +110,7 @@ class LMHPTransformation(SlicingTransformation):
result = self._combine_complex_result(resultReal, resultImag)
else:
result = lm_transformation_factory.buildLm(inp, lmax=lmax)
result = lm_transformation_helper.buildLm(inp, lmax=lmax)
result = pyHealpix.alm2map(result, lmax, mmax, nside)
return result
......@@ -312,9 +312,8 @@ class FFTW(Transform):
try:
# Create return object and insert results inplace
result_dtype = np.result_type(np.complex, self.codomain.dtype)
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
dtype=np.complex)
return_val.set_local_data(data=local_result, copy=False)
except(AttributeError):
return_val = local_result
......@@ -341,9 +340,8 @@ class FFTW(Transform):
np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
)
local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)
result_dtype = np.result_type(np.complex, self.codomain.dtype)
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
dtype=np.complex)
# Extract local data
local_val = val.get_local_data(copy=False)
......@@ -364,7 +362,7 @@ class FFTW(Transform):
if temp_val is None:
temp_val = np.empty_like(
local_val,
dtype=result_dtype
dtype=np.complex
)
inp = local_val[slice_list]
......@@ -439,6 +437,11 @@ class FFTW(Transform):
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
if val.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(val.dtype))
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
# Cast to a np.ndarray
......@@ -583,9 +586,13 @@ class NUMPYFFT(Transform):
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
result_dtype = np.result_type(np.complex, self.codomain.dtype)
if val.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(val.dtype))
return_val = val.copy_empty(global_shape=val.shape,
dtype=result_dtype)
dtype=np.complex)
if (axes is None) or (0 in axes) or \
(val.distribution_strategy not in STRATEGIES['slicing']):
......
......@@ -24,8 +24,7 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain,
module=module)
super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None:
if nifty_configuration['fft_module'] == 'fftw':
......@@ -33,7 +32,7 @@ class RGRGTransformation(Transformation):
elif nifty_configuration['fft_module'] == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
else:
raise ValueError('ERROR: unknow default FFT module:' +
raise ValueError('Unsupported default FFT module:' +
nifty_configuration['fft_module'])
else:
if module == 'fftw':
......@@ -41,10 +40,10 @@ class RGRGTransformation(Transformation):
elif module == 'numpy':
self._transform = NUMPYFFT(self.domain, self.codomain)
else:
raise ValueError('ERROR: unknow FFT module:' + module)
raise ValueError('Unsupported FFT module:' + module)
@classmethod
def get_codomain(cls, domain, dtype=None, zerocenter=None):
def get_codomain(cls, domain, zerocenter=None):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ either a shifted grid or a Fourier conjugate
......@@ -82,8 +81,7 @@ class RGRGTransformation(Transformation):
new_space = RGSpace(domain.shape,
zerocenter=zerocenter,
distances=distances,
harmonic=(not domain.harmonic),
dtype=domain.dtype)
harmonic=(not domain.harmonic))
# better safe than sorry
cls.check_codomain(domain, new_space)
......
......@@ -18,8 +18,6 @@
import abc
import numpy as np
from keepers import Loggable
......@@ -40,14 +38,12 @@ class Transformation(Loggable, object):
self.codomain = codomain
@classmethod
def get_codomain(cls, domain, dtype=None, zerocenter=None):
def get_codomain(cls, domain):
raise NotImplementedError
@classmethod
def check_codomain(cls, domain, codomain):
if np.dtype(domain.dtype) != np.dtype(codomain.dtype):
cls.Logger.warn("Unrecommended: domain and codomain don't have "
"the same dtype.")
pass
def transform(self, val, axes=None, **kwargs):
raise NotImplementedError
......@@ -45,8 +45,6 @@ class GLSpace(Space):
Number of latitudinal bins, or rings.
nlon : int, *optional*
Number of longitudinal bins (default: ``2*nlat - 1``).
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
See Also
--------
......@@ -62,15 +60,11 @@ class GLSpace(Space):
High-Resolution Discretization and Fast Analysis of Data
Distributed on the Sphere", *ApJ* 622..759G.
Attributes
----------
dtype : numpy.dtype
Data type of the field values.
"""
# ---Overwritten properties and methods---
def __init__(self, nlat, nlon=None, dtype=None):
def __init__(self, nlat, nlon=None):
"""
Sets the attributes for a gl_space class instance.
......@@ -80,8 +74,6 @@ class GLSpace(Space):
Number of latitudinal bins, or rings.
nlon : int, *optional*
Number of longitudinal bins (default: ``2*nlat - 1``).
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
Returns
-------
......@@ -97,7 +89,7 @@ class GLSpace(Space):
raise ImportError(
"The module pyHealpix is needed but not available.")
super(GLSpace, self).__init__(dtype)
super(GLSpace, self).__init__()
self._nlat = self._parse_nlat(nlat)
self._nlon = self._parse_nlon(nlon)
......@@ -122,8 +114,7 @@ class GLSpace(Space):
def copy(self):
return self.__class__(nlat=self.nlat,
nlon=self.nlon,
dtype=self.dtype)
nlon=self.nlon)
def weight(self, x, power=1, axes=None, inplace=False):
nlon = self.nlon
......@@ -184,7 +175,6 @@ class GLSpace(Space):
def _to_hdf5(self, hdf5_group):
hdf5_group['nlat'] = self.nlat
hdf5_group['nlon'] = self.nlon
hdf5_group.attrs['dtype'] = self.dtype.name
return None
......@@ -193,10 +183,6 @@ class GLSpace(Space):
result = cls(
nlat=hdf5_group['nlat'][()],
nlon=hdf5_group['nlon'][()],
dtype=np.dtype(hdf5_group.attrs['dtype'])
)
return result
def plot(self):
pass
\ No newline at end of file
......@@ -54,15 +54,11 @@ class HPSpace(Space):
harmonic transforms revisited";
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
Attributes
----------
dtype : numpy.dtype
Data type of the field values, which is always numpy.float64.
"""
# ---Overwritten properties and methods---
def __init__(self, nside, dtype=None):
def __init__(self, nside):