Commit 9ba59e8c authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'space_without_dtype_2' into 'master'

Space without dtype 2

See merge request !77
parents c9d48415 64a27773
Pipeline #11885 passed with stages
in 16 minutes and 6 seconds
...@@ -10,10 +10,10 @@ rank = comm.rank ...@@ -10,10 +10,10 @@ rank = comm.rank
if __name__ == "__main__": if __name__ == "__main__":
distribution_strategy = 'fftw' distribution_strategy = 'not'
# Setting up the geometry # Setting up the geometry
s_space = RGSpace([512, 512], dtype=np.float64) s_space = RGSpace([512, 512])
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
...@@ -54,4 +54,3 @@ if __name__ == "__main__": ...@@ -54,4 +54,3 @@ if __name__ == "__main__":
# pl.plot([go.Heatmap(z=d_data)], filename='data.html') # pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html') # pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html') # pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
...@@ -53,10 +53,10 @@ class WienerFilterEnergy(Energy): ...@@ -53,10 +53,10 @@ class WienerFilterEnergy(Energy):
if __name__ == "__main__": if __name__ == "__main__":
distribution_strategy = 'fftw' distribution_strategy = 'not'
# Set up spaces and fft transformation # Set up spaces and fft transformation
s_space = RGSpace([512, 512], dtype=np.float) s_space = RGSpace([512, 512])
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
......
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
import abc import abc
import numpy as np
from keepers import Loggable,\ from keepers import Loggable,\
Versionable Versionable
...@@ -27,8 +25,7 @@ from keepers import Loggable,\ ...@@ -27,8 +25,7 @@ from keepers import Loggable,\
class DomainObject(Versionable, Loggable, object): class DomainObject(Versionable, Loggable, object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, dtype): def __init__(self):
self._dtype = np.dtype(dtype)
self._ignore_for_hash = [] self._ignore_for_hash = []
def __hash__(self): def __hash__(self):
...@@ -50,10 +47,6 @@ class DomainObject(Versionable, Loggable, object): ...@@ -50,10 +47,6 @@ class DomainObject(Versionable, Loggable, object):
def __ne__(self, x): def __ne__(self, x):
return not self.__eq__(x) return not self.__eq__(x)
@property
def dtype(self):
return self._dtype
@abc.abstractproperty @abc.abstractproperty
def shape(self): def shape(self):
raise NotImplementedError( raise NotImplementedError(
...@@ -78,10 +71,9 @@ class DomainObject(Versionable, Loggable, object): ...@@ -78,10 +71,9 @@ class DomainObject(Versionable, Loggable, object):
# ---Serialization--- # ---Serialization---
def _to_hdf5(self, hdf5_group): def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None return None
@classmethod @classmethod
def _from_hdf5(cls, hdf5_group, repository): def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype'])) result = cls()
return result return result
...@@ -45,8 +45,7 @@ class Field(Loggable, Versionable, object): ...@@ -45,8 +45,7 @@ class Field(Loggable, Versionable, object):
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
self.dtype = self._infer_dtype(dtype=dtype, self.dtype = self._infer_dtype(dtype=dtype,
val=val, val=val)
domain=self.domain)
self.distribution_strategy = self._parse_distribution_strategy( self.distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
...@@ -86,18 +85,17 @@ class Field(Loggable, Versionable, object): ...@@ -86,18 +85,17 @@ class Field(Loggable, Versionable, object):
axes_list += [tuple(l)] axes_list += [tuple(l)]
return tuple(axes_list) return tuple(axes_list)
def _infer_dtype(self, dtype, val, domain): def _infer_dtype(self, dtype, val):
if dtype is None: if dtype is None:
if isinstance(val, Field) or \ try:
isinstance(val, distributed_data_object):
dtype = val.dtype dtype = val.dtype
dtype_tuple = (np.dtype(gc['default_field_dtype']),) except AttributeError:
if val is not None:
dtype = np.result_type(val)
else:
dtype = np.dtype(gc['default_field_dtype'])
else: else:
dtype_tuple = (np.dtype(dtype),) dtype = np.dtype(dtype)
if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype return dtype
...@@ -345,7 +343,7 @@ class Field(Loggable, Versionable, object): ...@@ -345,7 +343,7 @@ class Field(Loggable, Versionable, object):
# create random samples: one or two, depending on whether the # create random samples: one or two, depending on whether the
# power spectrum is real or complex # power spectrum is real or complex
if issubclass(power_domain.dtype.type, np.complexfloating): if issubclass(self.dtype.type, np.complexfloating):
result_list = [None, None] result_list = [None, None]
else: else:
result_list = [None] result_list = [None]
...@@ -355,7 +353,7 @@ class Field(Loggable, Versionable, object): ...@@ -355,7 +353,7 @@ class Field(Loggable, Versionable, object):
mean=mean, mean=mean,
std=std, std=std,
domain=result_domain, domain=result_domain,
dtype=harmonic_domain.dtype, dtype=self.dtype,
distribution_strategy=self.distribution_strategy) distribution_strategy=self.distribution_strategy)
for x in result_list] for x in result_list]
...@@ -403,7 +401,7 @@ class Field(Loggable, Versionable, object): ...@@ -403,7 +401,7 @@ class Field(Loggable, Versionable, object):
lambda x: x * local_rescaler.real, lambda x: x * local_rescaler.real,
inplace=True) inplace=True)
if issubclass(power_domain.dtype.type, np.complexfloating): if issubclass(self.dtype.type, np.complexfloating):
result_val_list[1].apply_scalar_function( result_val_list[1].apply_scalar_function(
lambda x: x * local_rescaler.imag, lambda x: x * local_rescaler.imag,
inplace=True) inplace=True)
...@@ -412,7 +410,7 @@ class Field(Loggable, Versionable, object): ...@@ -412,7 +410,7 @@ class Field(Loggable, Versionable, object):
[x.set_val(new_val=y, copy=False) for x, y in [x.set_val(new_val=y, copy=False) for x, y in
zip(result_list, result_val_list)] zip(result_list, result_val_list)]
if issubclass(power_domain.dtype.type, np.complexfloating): if issubclass(self.dtype.type, np.complexfloating):
result = result_list[0] + 1j*result_list[1] result = result_list[0] + 1j*result_list[1]
else: else:
result = result_list[0] result = result_list[0]
......
...@@ -83,12 +83,14 @@ class FFTOperator(LinearOperator): ...@@ -83,12 +83,14 @@ class FFTOperator(LinearOperator):
# Store the dtype information # Store the dtype information
if domain_dtype is None: if domain_dtype is None:
self.domain_dtype = None self.logger.info("Setting domain_dtype to np.float.")
self.domain_dtype = np.float
else: else:
self.domain_dtype = np.dtype(domain_dtype) self.domain_dtype = np.dtype(domain_dtype)
if target_dtype is None: if target_dtype is None:
self.target_dtype = None self.logger.info("Setting target_dtype to np.complex.")
self.target_dtype = np.complex
else: else:
self.target_dtype = np.dtype(target_dtype) self.target_dtype = np.dtype(target_dtype)
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
from nifty.config import dependency_injector as gdi from nifty.config import dependency_injector as gdi
from nifty import GLSpace, LMSpace from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation from slicing_transformation import SlicingTransformation
import lm_transformation_factory import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix') pyHealpix = gdi.get('pyHealpix')
...@@ -31,12 +31,17 @@ class GLLMTransformation(SlicingTransformation): ...@@ -31,12 +31,17 @@ class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available.") "The module pyHealpix is needed but not available.")
super(GLLMTransformation, self).__init__(domain, codomain, super(GLLMTransformation, self).__init__(domain, codomain, module)
module=module)
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
...@@ -63,7 +68,7 @@ class GLLMTransformation(SlicingTransformation): ...@@ -63,7 +68,7 @@ class GLLMTransformation(SlicingTransformation):
nlat = domain.nlat nlat = domain.nlat
lmax = nlat - 1 lmax = nlat - 1
result = LMSpace(lmax=lmax, dtype=domain.dtype) result = LMSpace(lmax=lmax)
return result return result
@classmethod @classmethod
...@@ -91,6 +96,11 @@ class GLLMTransformation(SlicingTransformation): ...@@ -91,6 +96,11 @@ class GLLMTransformation(SlicingTransformation):
super(GLLMTransformation, cls).check_codomain(domain, codomain) super(GLLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs): def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
nlat = self.domain.nlat nlat = self.domain.nlat
nlon = self.domain.nlon nlon = self.domain.nlon
lmax = self.codomain.lmax lmax = self.codomain.lmax
...@@ -104,12 +114,12 @@ class GLLMTransformation(SlicingTransformation): ...@@ -104,12 +114,12 @@ class GLLMTransformation(SlicingTransformation):
for x in (inp.real, inp.imag)] for x in (inp.real, inp.imag)]
[resultReal, [resultReal,
resultImag] = [lm_transformation_factory.buildIdx(x, lmax=lmax) resultImag] = [lm_transformation_helper.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]] for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag) result = self._combine_complex_result(resultReal, resultImag)
else: else:
result = sjob.map2alm(inp) result = sjob.map2alm(inp)
result = lm_transformation_factory.buildIdx(result, lmax=lmax) result = lm_transformation_helper.buildIdx(result, lmax=lmax)
return result return result
...@@ -22,7 +22,7 @@ from nifty.config import dependency_injector as gdi ...@@ -22,7 +22,7 @@ from nifty.config import dependency_injector as gdi
from nifty import HPSpace, LMSpace from nifty import HPSpace, LMSpace
from slicing_transformation import SlicingTransformation from slicing_transformation import SlicingTransformation
import lm_transformation_factory import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix') pyHealpix = gdi.get('pyHealpix')
...@@ -32,12 +32,17 @@ class HPLMTransformation(SlicingTransformation): ...@@ -32,12 +32,17 @@ class HPLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available") "The module pyHealpix is needed but not available")
super(HPLMTransformation, self).__init__(domain, codomain, super(HPLMTransformation, self).__init__(domain, codomain, module)
module=module)
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
...@@ -63,7 +68,7 @@ class HPLMTransformation(SlicingTransformation): ...@@ -63,7 +68,7 @@ class HPLMTransformation(SlicingTransformation):
lmax = 2*domain.nside lmax = 2*domain.nside
result = LMSpace(lmax=lmax, dtype=domain.dtype) result = LMSpace(lmax=lmax)
return result return result
@classmethod @classmethod
...@@ -83,6 +88,11 @@ class HPLMTransformation(SlicingTransformation): ...@@ -83,6 +88,11 @@ class HPLMTransformation(SlicingTransformation):
super(HPLMTransformation, cls).check_codomain(domain, codomain) super(HPLMTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs): def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
lmax = self.codomain.lmax lmax = self.codomain.lmax
mmax = lmax mmax = lmax
...@@ -92,13 +102,13 @@ class HPLMTransformation(SlicingTransformation): ...@@ -92,13 +102,13 @@ class HPLMTransformation(SlicingTransformation):
for x in (inp.real, inp.imag)] for x in (inp.real, inp.imag)]
[resultReal, [resultReal,
resultImag] = [lm_transformation_factory.buildIdx(x, lmax=lmax) resultImag] = [lm_transformation_helper.buildIdx(x, lmax=lmax)
for x in [resultReal, resultImag]] for x in [resultReal, resultImag]]
result = self._combine_complex_result(resultReal, resultImag) result = self._combine_complex_result(resultReal, resultImag)
else: else:
result = pyHealpix.map2alm_iter(inp, lmax, mmax, 3) result = pyHealpix.map2alm_iter(inp, lmax, mmax, 3)
result = lm_transformation_factory.buildIdx(result, lmax=lmax) result = lm_transformation_helper.buildIdx(result, lmax=lmax)
return result return result
...@@ -21,7 +21,7 @@ from nifty.config import dependency_injector as gdi ...@@ -21,7 +21,7 @@ from nifty.config import dependency_injector as gdi
from nifty import GLSpace, LMSpace from nifty import GLSpace, LMSpace
from slicing_transformation import SlicingTransformation from slicing_transformation import SlicingTransformation
import lm_transformation_factory import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix') pyHealpix = gdi.get('pyHealpix')
...@@ -31,12 +31,17 @@ class LMGLTransformation(SlicingTransformation): ...@@ -31,12 +31,17 @@ class LMGLTransformation(SlicingTransformation):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if 'pyHealpix' not in gdi: if 'pyHealpix' not in gdi:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available.") "The module pyHealpix is needed but not available.")
super(LMGLTransformation, self).__init__(domain, codomain, super(LMGLTransformation, self).__init__(domain, codomain, module)
module=module)
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
...@@ -69,7 +74,7 @@ class LMGLTransformation(SlicingTransformation): ...@@ -69,7 +74,7 @@ class LMGLTransformation(SlicingTransformation):
nlat = domain.lmax + 1 nlat = domain.lmax + 1
nlon = domain.lmax*2 + 1 nlon = domain.lmax*2 + 1
result = GLSpace(nlat=nlat, nlon=nlon, dtype=domain.dtype) result = GLSpace(nlat=nlat, nlon=nlon)
return result return result
@classmethod @classmethod
...@@ -97,6 +102,11 @@ class LMGLTransformation(SlicingTransformation): ...@@ -97,6 +102,11 @@ class LMGLTransformation(SlicingTransformation):
super(LMGLTransformation, cls).check_codomain(domain, codomain) super(LMGLTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs): def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
nlat = self.codomain.nlat nlat = self.codomain.nlat
nlon = self.codomain.nlon nlon = self.codomain.nlon
lmax = self.domain.lmax lmax = self.domain.lmax
...@@ -107,7 +117,7 @@ class LMGLTransformation(SlicingTransformation): ...@@ -107,7 +117,7 @@ class LMGLTransformation(SlicingTransformation):
sjob.set_triangular_alm_info(lmax, mmax) sjob.set_triangular_alm_info(lmax, mmax)
if issubclass(inp.dtype.type, np.complexfloating): if issubclass(inp.dtype.type, np.complexfloating):
[resultReal, [resultReal,
resultImag] = [lm_transformation_factory.buildLm(x, lmax=lmax) resultImag] = [lm_transformation_helper.buildLm(x, lmax=lmax)
for x in (inp.real, inp.imag)] for x in (inp.real, inp.imag)]
[resultReal, resultImag] = [sjob.alm2map(x) [resultReal, resultImag] = [sjob.alm2map(x)
...@@ -116,7 +126,7 @@ class LMGLTransformation(SlicingTransformation): ...@@ -116,7 +126,7 @@ class LMGLTransformation(SlicingTransformation):
result = self._combine_complex_result(resultReal, resultImag) result = self._combine_complex_result(resultReal, resultImag)
else: else:
result = lm_transformation_factory.buildLm(inp, lmax=lmax) result = lm_transformation_helper.buildLm(inp, lmax=lmax)
result = sjob.alm2map(result) result = sjob.alm2map(result)
return result return result
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
from nifty.config import dependency_injector as gdi from nifty.config import dependency_injector as gdi
from nifty import HPSpace, LMSpace from nifty import HPSpace, LMSpace
from slicing_transformation import SlicingTransformation from slicing_transformation import SlicingTransformation
import lm_transformation_factory import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix') pyHealpix = gdi.get('pyHealpix')
...@@ -30,12 +30,17 @@ class LMHPTransformation(SlicingTransformation): ...@@ -30,12 +30,17 @@ class LMHPTransformation(SlicingTransformation):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None): def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if gdi.get('pyHealpix') is None: if gdi.get('pyHealpix') is None:
raise ImportError( raise ImportError(
"The module pyHealpix is needed but not available.") "The module pyHealpix is needed but not available.")
super(LMHPTransformation, self).__init__(domain, codomain, super(LMHPTransformation, self).__init__(domain, codomain, module)
module=module)
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
...@@ -65,7 +70,7 @@ class LMHPTransformation(SlicingTransformation): ...@@ -65,7 +70,7 @@ class LMHPTransformation(SlicingTransformation):
raise TypeError("domain needs to be a LMSpace.") raise TypeError("domain needs to be a LMSpace.")
nside = max((domain.lmax + 1)//2, 1) nside = max((domain.lmax + 1)//2, 1)
result = HPSpace(nside=nside, dtype=domain.dtype) result = HPSpace(nside=nside)
return result return result
@classmethod @classmethod
...@@ -85,13 +90,18 @@ class LMHPTransformation(SlicingTransformation): ...@@ -85,13 +90,18 @@ class LMHPTransformation(SlicingTransformation):
super(LMHPTransformation, cls).check_codomain(domain, codomain) super(LMHPTransformation, cls).check_codomain(domain, codomain)
def _transformation_of_slice(self, inp, **kwargs): def _transformation_of_slice(self, inp, **kwargs):
if inp.dtype not in (np.float, np.complex):
self.logger.warn("The input array has dtype: %s. The FFT will "
"be performed at double precision." %
str(inp.dtype))
nside = self.codomain.nside nside = self.codomain.nside
lmax = self.domain.lmax lmax = self.domain.lmax