Commit d6894d23 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

stage1

parent 766aa216
...@@ -24,16 +24,11 @@ from .version import __version__ ...@@ -24,16 +24,11 @@ from .version import __version__
from keepers import MPILogger from keepers import MPILogger
logger = MPILogger() logger = MPILogger()
# it is important to import config before d2o such that NIFTy is able to
# pre-create d2o's configuration object with the corrected path
from .config import dependency_injector,\ from .config import dependency_injector,\
nifty_configuration,\ nifty_configuration
d2o_configuration
logger.logger.setLevel(nifty_configuration['loglevel']) logger.logger.setLevel(nifty_configuration['loglevel'])
from d2o import distributed_data_object, d2o_librarian
from .field import Field from .field import Field
from .random import Random from .random import Random
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from d2o import distributed_data_object
from .field import Field from .field import Field
...@@ -32,8 +31,6 @@ def _math_helper(x, function): ...@@ -32,8 +31,6 @@ def _math_helper(x, function):
result_val = x.val.apply_scalar_function(function) result_val = x.val.apply_scalar_function(function)
result = x.copy_empty(dtype=result_val.dtype) result = x.copy_empty(dtype=result_val.dtype)
result.val = result_val result.val = result_val
elif isinstance(x, distributed_data_object):
result = x.apply_scalar_function(function, inplace=False)
else: else:
result = function(np.asarray(x)) result = function(np.asarray(x))
......
...@@ -142,7 +142,7 @@ class DomainObject(with_metaclass( ...@@ -142,7 +142,7 @@ class DomainObject(with_metaclass(
Parameters Parameters
---------- ----------
x : distributed_data_object x : numpy.ndarray
The fields data array. The fields data array.
power : int, *optional* power : int, *optional*
The power to which the volume-weight is raised (default: 1). The power to which the volume-weight is raised (default: 1).
...@@ -158,7 +158,7 @@ class DomainObject(with_metaclass( ...@@ -158,7 +158,7 @@ class DomainObject(with_metaclass(
Returns Returns
------- -------
distributed_data_object numpy.ndarray
A weighted version of x, with volume-weights raised to the A weighted version of x, with volume-weights raised to the
given power. given power.
...@@ -217,7 +217,7 @@ class DomainObject(with_metaclass( ...@@ -217,7 +217,7 @@ class DomainObject(with_metaclass(
Returns Returns
------- -------
distributed_data_object numpy.ndarray
Processed input where casting that needs Space-specific knowledge Processed input where casting that needs Space-specific knowledge
(for example location of pixels on the manifold) was performed. (for example location of pixels on the manifold) was performed.
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
from __future__ import division from __future__ import division
from builtins import zip from builtins import zip
#from builtins import str
from builtins import range from builtins import range
import ast import ast
...@@ -27,9 +26,6 @@ import numpy as np ...@@ -27,9 +26,6 @@ import numpy as np
from keepers import Versionable,\ from keepers import Versionable,\
Loggable Loggable
from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES
from .config import nifty_configuration as gc from .config import nifty_configuration as gc
from .domain_object import DomainObject from .domain_object import DomainObject
...@@ -55,7 +51,7 @@ class Field(Loggable, Versionable, object): ...@@ -55,7 +51,7 @@ class Field(Loggable, Versionable, object):
LMSpace or PowerSpace. It might also be a FieldArray, which is LMSpace or PowerSpace. It might also be a FieldArray, which is
an unstructured domain. an unstructured domain.
val : scalar, numpy.ndarray, distributed_data_object, Field val : scalar, numpy.ndarray, Field
The values the array should contain after init. A scalar input will The values the array should contain after init. A scalar input will
fill the whole array with this scalar. If an array is provided the fill the whole array with this scalar. If an array is provided the
array's dimensions must match the domain's. array's dimensions must match the domain's.
...@@ -63,18 +59,11 @@ class Field(Loggable, Versionable, object): ...@@ -63,18 +59,11 @@ class Field(Loggable, Versionable, object):
dtype : type dtype : type
A numpy.type. Most common are int, float and complex. A numpy.type. Most common are int, float and complex.
distribution_strategy: optional[{'fftw', 'equal', 'not', 'freeform'}]
Specifies which distributor will be created and used.
'fftw' uses the distribution strategy of pyfftw,
'equal' tries to distribute the data as uniform as possible
'not' does not distribute the data at all
'freeform' distribute the data according to the given local data/shape
copy: boolean copy: boolean
Attributes Attributes
---------- ----------
val : distributed_data_object val : numpy.ndarray
domain : DomainObject domain : DomainObject
See Parameters. See Parameters.
...@@ -82,8 +71,6 @@ class Field(Loggable, Versionable, object): ...@@ -82,8 +71,6 @@ class Field(Loggable, Versionable, object):
Enumerates the axes of the Field Enumerates the axes of the Field
dtype : type dtype : type
Contains the datatype stored in the Field. Contains the datatype stored in the Field.
distribution_strategy : string
Name of the used distribution_strategy.
Raise Raise
----- -----
...@@ -93,38 +80,17 @@ class Field(Loggable, Versionable, object): ...@@ -93,38 +80,17 @@ class Field(Loggable, Versionable, object):
instance instance
*val is an array that has a different dimension than the domain *val is an array that has a different dimension than the domain
Examples
--------
>>> a = Field(RGSpace([4,5]),val=2)
>>> a.val
<distributed_data_object>
array([[2, 2, 2, 2, 2],
[2, 2, 2, 2, 2],
[2, 2, 2, 2, 2],
[2, 2, 2, 2, 2]])
>>> a.dtype
dtype('int64')
See Also
--------
distributed_data_object
""" """
# ---Initialization methods--- # ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, def __init__(self, domain=None, val=None, dtype=None, copy=False):
distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
self.dtype = self._infer_dtype(dtype=dtype, self.dtype = self._infer_dtype(dtype=dtype,
val=val) val=val)
self.distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy,
val=val)
if val is None: if val is None:
self._val = None self._val = None
else: else:
...@@ -177,26 +143,10 @@ class Field(Loggable, Versionable, object): ...@@ -177,26 +143,10 @@ class Field(Loggable, Versionable, object):
return dtype return dtype
def _parse_distribution_strategy(self, distribution_strategy, val):
if distribution_strategy is None:
if isinstance(val, distributed_data_object):
distribution_strategy = val.distribution_strategy
elif isinstance(val, Field):
distribution_strategy = val.distribution_strategy
else:
self.logger.debug("distribution_strategy set to default!")
distribution_strategy = gc['default_distribution_strategy']
elif distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
raise ValueError(
"distribution_strategy must be a global-type "
"strategy.")
return distribution_strategy
# ---Factory methods--- # ---Factory methods---
@classmethod @classmethod
def from_random(cls, random_type, domain=None, dtype=None, def from_random(cls, random_type, domain=None, dtype=None, **kwargs):
distribution_strategy=None, **kwargs):
""" Draws a random field with the given parameters. """ Draws a random field with the given parameters.
Parameters Parameters
...@@ -213,9 +163,6 @@ class Field(Loggable, Versionable, object): ...@@ -213,9 +163,6 @@ class Field(Loggable, Versionable, object):
dtype : type dtype : type
The datatype of the output random field The datatype of the output random field
distribution_strategy : all supported distribution strategies
The distribution strategy of the output random field
Returns Returns
------- -------
out : Field out : Field
...@@ -229,8 +176,7 @@ class Field(Loggable, Versionable, object): ...@@ -229,8 +176,7 @@ class Field(Loggable, Versionable, object):
""" """
# create a initially empty field # create a initially empty field
f = cls(domain=domain, dtype=dtype, f = cls(domain=domain, dtype=dtype)
distribution_strategy=distribution_strategy)
# now use the processed input in terms of f in order to parse the # now use the processed input in terms of f in order to parse the
# random arguments # random arguments
...@@ -238,23 +184,14 @@ class Field(Loggable, Versionable, object): ...@@ -238,23 +184,14 @@ class Field(Loggable, Versionable, object):
f=f, f=f,
**kwargs) **kwargs)
# extract the distributed_data_object from f and apply the appropriate # extract the data from f and apply the appropriate
# random number generator to it # random number generator to it
sample = f.get_val(copy=False) sample = f.get_val(copy=False)
generator_function = getattr(Random, random_type) generator_function = getattr(Random, random_type)
comm = sample.comm sample[:]=generator_function(dtype=f.dtype,
size = comm.size shape=sample.shape,
if (sample.distribution_strategy in DISTRIBUTION_STRATEGIES['not'] and **random_arguments)
size > 1):
seed = np.random.randint(10000000)
seed = comm.bcast(seed, root=0)
np.random.seed(seed)
sample.apply_generator(
lambda shape: generator_function(dtype=f.dtype,
shape=shape,
**random_arguments))
return f return f
@staticmethod @staticmethod
...@@ -400,13 +337,8 @@ class Field(Loggable, Versionable, object): ...@@ -400,13 +337,8 @@ class Field(Loggable, Versionable, object):
# into the real and imaginary parts of the power spectrum. # into the real and imaginary parts of the power spectrum.
# If it was complex, all the power is put into a real power spectrum. # If it was complex, all the power is put into a real power spectrum.
distribution_strategy = \
work_field.val.get_axes_local_distribution_strategy(
work_field.domain_axes[space_index])
harmonic_domain = work_field.domain[space_index] harmonic_domain = work_field.domain[space_index]
power_domain = PowerSpace(harmonic_partner=harmonic_domain, power_domain = PowerSpace(harmonic_partner=harmonic_domain,
distribution_strategy=distribution_strategy,
logarithmic=logarithmic, nbin=nbin, logarithmic=logarithmic, nbin=nbin,
binbounds=binbounds) binbounds=binbounds)
power_spectrum = cls._calculate_power_spectrum( power_spectrum = cls._calculate_power_spectrum(
...@@ -421,8 +353,7 @@ class Field(Loggable, Versionable, object): ...@@ -421,8 +353,7 @@ class Field(Loggable, Versionable, object):
result_field = work_field.copy_empty( result_field = work_field.copy_empty(
domain=result_domain, domain=result_domain,
dtype=result_dtype, dtype=result_dtype)
distribution_strategy=power_spectrum.distribution_strategy)
result_field.set_val(new_val=power_spectrum, copy=False) result_field.set_val(new_val=power_spectrum, copy=False)
return result_field return result_field
...@@ -437,7 +368,6 @@ class Field(Loggable, Versionable, object): ...@@ -437,7 +368,6 @@ class Field(Loggable, Versionable, object):
pindex = cls._shape_up_pindex( pindex = cls._shape_up_pindex(
pindex=pindex, pindex=pindex,
target_shape=field_val.shape, target_shape=field_val.shape,
target_strategy=field_val.distribution_strategy,
axes=axes) axes=axes)
power_spectrum = pindex.bincount(weights=field_val, power_spectrum = pindex.bincount(weights=field_val,
...@@ -453,31 +383,18 @@ class Field(Loggable, Versionable, object): ...@@ -453,31 +383,18 @@ class Field(Loggable, Versionable, object):
@staticmethod @staticmethod
def _shape_up_pindex(pindex, target_shape, target_strategy, axes): def _shape_up_pindex(pindex, target_shape, target_strategy, axes):
if pindex.distribution_strategy not in \
DISTRIBUTION_STRATEGIES['global']:
raise ValueError("pindex's distribution strategy must be "
"global-type")
if pindex.distribution_strategy in DISTRIBUTION_STRATEGIES['slicing']:
if ((0 not in axes) or
(target_strategy is not pindex.distribution_strategy)):
raise ValueError(
"A slicing distributor shall not be reshaped to "
"something non-sliced.")
semiscaled_local_shape = [1, ] * len(target_shape) semiscaled_local_shape = [1, ] * len(target_shape)
for i in range(len(axes)): for i in range(len(axes)):
semiscaled_local_shape[axes[i]] = pindex.local_shape[i] semiscaled_local_shape[axes[i]] = pindex.local_shape[i]
local_data = pindex.get_local_data(copy=False) local_data = pindex.get_local_data(copy=False)
semiscaled_local_data = local_data.reshape(semiscaled_local_shape) semiscaled_local_data = local_data.reshape(semiscaled_local_shape)
result_obj = pindex.copy_empty(global_shape=target_shape, result_obj = pindex.copy_empty(global_shape=target_shape)
distribution_strategy=target_strategy)
result_obj.data[:] = semiscaled_local_data result_obj.data[:] = semiscaled_local_data
return result_obj return result_obj
def power_synthesize(self, spaces=None, real_power=True, real_signal=True, def power_synthesize(self, spaces=None, real_power=True, real_signal=True,
mean=None, std=None, distribution_strategy=None): mean=None, std=None):
""" Yields a sampled field with `self`**2 as its power spectrum. """ Yields a sampled field with `self`**2 as its power spectrum.
This method draws a Gaussian random field in the harmonic partner This method draws a Gaussian random field in the harmonic partner
...@@ -552,16 +469,12 @@ class Field(Loggable, Versionable, object): ...@@ -552,16 +469,12 @@ class Field(Loggable, Versionable, object):
else: else:
result_list = [None, None] result_list = [None, None]
if distribution_strategy is None:
distribution_strategy = gc['default_distribution_strategy']
result_list = [self.__class__.from_random( result_list = [self.__class__.from_random(
'normal', 'normal',
mean=mean, mean=mean,
std=std, std=std,
domain=result_domain, domain=result_domain,
dtype=np.complex, dtype=np.complex)
distribution_strategy=distribution_strategy)
for x in result_list] for x in result_list]
# from now on extract the values from the random fields for further # from now on extract the values from the random fields for further
...@@ -569,7 +482,7 @@ class Field(Loggable, Versionable, object): ...@@ -569,7 +482,7 @@ class Field(Loggable, Versionable, object):
# if the signal-space field should be real, hermitianize the field # if the signal-space field should be real, hermitianize the field
# components # components
spec = self.val.get_full_data() spec = self.val.copy()
spec = np.sqrt(spec) spec = np.sqrt(spec)
for power_space_index in spaces: for power_space_index in spaces:
...@@ -683,16 +596,6 @@ class Field(Loggable, Versionable, object): ...@@ -683,16 +596,6 @@ class Field(Loggable, Versionable, object):
# weight the random fields with the power spectrum # weight the random fields with the power spectrum
# therefore get the pindex from the power space # therefore get the pindex from the power space
pindex = power_space.pindex pindex = power_space.pindex
# take the local data from pindex. This data must be compatible to the
# local data of the field given the slice of the PowerSpace
local_distribution_strategy = \
result_list[0].val.get_axes_local_distribution_strategy(
result_list[0].domain_axes[power_space_index])
if pindex.distribution_strategy is not local_distribution_strategy:
raise AttributeError(
"The distribution_strategy of pindex does not fit the "
"slice_local distribution strategy of the synthesized field.")
# Now use numpy advanced indexing in order to put the entries of the # Now use numpy advanced indexing in order to put the entries of the
# power spectrum into the appropriate places of the pindex array. # power spectrum into the appropriate places of the pindex array.
...@@ -711,7 +614,7 @@ class Field(Loggable, Versionable, object): ...@@ -711,7 +614,7 @@ class Field(Loggable, Versionable, object):
# ---Properties--- # ---Properties---
def set_val(self, new_val=None, copy=False): def set_val(self, new_val=None, copy=False):
""" Sets the field's distributed_data_object. """ Sets the field's data object.
Parameters Parameters
---------- ----------
...@@ -736,17 +639,17 @@ class Field(Loggable, Versionable, object): ...@@ -736,17 +639,17 @@ class Field(Loggable, Versionable, object):
return self return self
def get_val(self, copy=False): def get_val(self, copy=False):
""" Returns the distributed_data_object associated with this Field. """ Returns the data object associated with this Field.
Parameters Parameters
---------- ----------
copy : boolean copy : boolean
If true, a copy of the Field's underlying distributed_data_object If true, a copy of the Field's underlying data object
is returned. is returned.
Returns Returns
------- -------
out : distributed_data_object out : numpy.ndarray
See Also See Also
-------- --------
...@@ -764,11 +667,11 @@ class Field(Loggable, Versionable, object): ...@@ -764,11 +667,11 @@ class Field(Loggable, Versionable, object):
@property @property
def val(self): def val(self):
""" Returns the distributed_data_object associated with this Field. """ Returns the data object associated with this Field.
Returns Returns
------- -------
out : distributed_data_object out : numpy.ndarray
See Also See Also
-------- --------
...@@ -874,13 +777,13 @@ class Field(Loggable, Versionable, object): ...@@ -874,13 +777,13 @@ class Field(Loggable, Versionable, object):
# ---Special unary/binary operations--- # ---Special unary/binary operations---
def cast(self, x=None, dtype=None): def cast(self, x=None, dtype=None):
""" Transforms x to a d2o with the correct dtype and shape. """ Transforms x to an object with the correct dtype and shape.
Parameters Parameters
---------- ----------
x : scalar, d2o, Field, array_like x : scalar, numpy.ndarray, Field, array_like
The input that shall be casted on a d2o of the same shape like the The input that shall be casted on a numpy.ndarray of the same shape
domain. like the domain.
dtype : type dtype : type
The datatype the output shall have. This can be used to override The datatype the output shall have. This can be used to override
...@@ -888,7 +791,7 @@ class Field(Loggable, Versionable, object): ...@@ -888,7 +791,7 @@ class Field(Loggable, Versionable, object):
Returns Returns
------- -------
out : distributed_data_object out : numpy.ndarray
The output object. The output object.
See Also See Also
...@@ -921,21 +824,17 @@ class Field(Loggable, Versionable, object): ...@@ -921,21 +824,17 @@ class Field(Loggable, Versionable, object):
if dtype is None: if dtype is None:
dtype = self.dtype dtype = self.dtype
if x is not None:
return np.asarray(x, dtype=dtype).reshape(self.shape)
else:
return np.empty(self.shape, dtype=dtype)
return_x = distributed_data_object( def copy(self, domain=None, dtype=None):
global_shape=self.shape,
dtype=dtype,
distribution_strategy=self.distribution_strategy)
return_x.set_full_data(x, copy=False)
return return_x
def copy(self, domain=None, dtype=None, distribution_strategy=None):
""" Returns a full copy of the Field. """ Returns a full copy of the Field.
If no keyword arguments are given, the returned object will be an If no keyword arguments are given, the returned object will be an
identical copy of the original Field. By explicit specification one is identical copy of the original Field. By explicit specification one is
able to define the domain, the dtype and the distribution_strategy of able to define the domain and the dtype of the returned Field.
the returned Field.
Parameters Parameters
---------- ----------
...@@ -945,9 +844,6 @@ class Field(Loggable, Versionable, object): ...@@ -945,9 +844,6 @@ class Field(Loggable, Versionable, object):
dtype : type dtype : type
The new dtype the Field shall have. The new dtype the Field shall have.
distribution_strategy : all supported distribution strategies
The new distribution strategy the Field shall have.
Returns Returns
------- -------
out : Field out : Field
...@@ -962,20 +858,18 @@ class Field(Loggable, Versionable, object): ...@@ -962,20 +858,18 @@ class Field(Loggable, Versionable, object):
copied_val = self.get_val(copy=True) copied_val = self.get_val(copy=True)
new_field = self.copy_empty( new_field = self.copy_empty(
domain=domain, domain=domain,
dtype=dtype, dtype=dtype)
distribution_strategy=distribution_strategy)
new_field.set_val(new_val=copied_val, copy=False) new_field.