Commit 0b5464f3 authored by Martin Reinecke

first working version

parent e7bc9096
@@ -19,5 +19,3 @@
from .nifty_config import dependency_injector,\
                          nifty_configuration
-from .d2o_config import d2o_configuration
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import os
import keepers
# pre-create the D2O configuration instance and set its path explicitly
d2o_configuration = keepers.get_Configuration(
    name='D2O',
    file_name='D2O.conf',
    search_paths=[os.path.expanduser('~') + "/.config/nifty/",
                  os.path.expanduser('~') + "/.config/",
                  './'])
@@ -34,8 +34,6 @@ dependency_injector = keepers.DependencyInjector(

def _fft_module_checker(z):
-   if z == 'fftw_mpi':
-       return hasattr(dependency_injector.get('fftw'), 'FFTW_MPI')
    if z == 'fftw':
        return ('fftw' in dependency_injector)
    if z == 'numpy':
@@ -45,7 +43,7 @@ def _fft_module_checker(z):

# Initialize the variables
variable_fft_module = keepers.Variable(
                          'fft_module',
-                         ['fftw_mpi', 'fftw', 'numpy'],
+                         ['fftw', 'numpy'],
                          _fft_module_checker)
...
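The checker above asks the keepers dependency_injector whether a given FFT backend is usable before the corresponding configuration value is accepted. As a rough stand-alone sketch of the same availability test (this is not the keepers API, only an illustration that probes for the optional pyfftw package):

import importlib.util

def fft_module_available(name):
    # 'numpy' is always usable; 'fftw' needs the optional pyfftw package
    if name == 'numpy':
        return True
    if name == 'fftw':
        return importlib.util.find_spec('pyfftw') is not None
    return False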
@@ -370,7 +370,7 @@ class Field(Loggable, Versionable, object):
                                           target_shape=field_val.shape,
                                           axes=axes)
-        power_spectrum = pindex.bincount(weights=field_val,
-                                         axes=axes)
+        power_spectrum = utilities.bincount_axis(pindex, weights=field_val,
+                                                 axis=axes)
        rho = pdomain.rho
        if axes is not None:
@@ -382,14 +382,14 @@ class Field(Loggable, Versionable, object):
        return power_spectrum

    @staticmethod
-   def _shape_up_pindex(pindex, target_shape, target_strategy, axes):
+   def _shape_up_pindex(pindex, target_shape, axes):
        semiscaled_local_shape = [1, ] * len(target_shape)
        for i in range(len(axes)):
-           semiscaled_local_shape[axes[i]] = pindex.local_shape[i]
-       local_data = pindex.get_local_data(copy=False)
+           semiscaled_local_shape[axes[i]] = pindex.shape[i]
+       local_data = pindex
        semiscaled_local_data = local_data.reshape(semiscaled_local_shape)
-       result_obj = pindex.copy_empty(global_shape=target_shape)
-       result_obj.data[:] = semiscaled_local_data
+       result_obj = np.empty(target_shape, dtype=pindex.dtype)
+       result_obj[:] = semiscaled_local_data
        return result_obj
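The reworked _shape_up_pindex relies on plain numpy broadcasting: the pindex entries are placed on the selected axes of the target shape and repeated along every other axis. A small toy with invented shapes:

import numpy as np

pindex = np.arange(6).reshape(2, 3)   # covers axes (1, 2) of the target
target_shape = (4, 2, 3)
axes = (1, 2)

semiscaled_shape = [1] * len(target_shape)
for i, ax in enumerate(axes):
    semiscaled_shape[ax] = pindex.shape[i]        # -> [1, 2, 3]

result = np.empty(target_shape, dtype=pindex.dtype)
result[:] = pindex.reshape(semiscaled_shape)      # broadcast along axis 0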
@@ -492,14 +492,10 @@ class Field(Loggable, Versionable, object):
        result_val_list = [x.val for x in result_list]

        # apply the rescaler to the random fields
-       result_val_list[0].apply_scalar_function(
-           lambda x: x * local_rescaler.real,
-           inplace=True)
+       result_val_list[0] *= local_rescaler.real

        if not real_power:
-           result_val_list[1].apply_scalar_function(
-               lambda x: x * local_rescaler.imag,
-               inplace=True)
+           result_val_list[1] *= local_rescaler.imag

        if real_signal:
            result_val_list = [self._hermitian_decomposition(
@@ -538,8 +534,8 @@ class Field(Loggable, Versionable, object):
        # no flips are applied, one can use `is` to infer this case.
        if flipped_val is val:
-           h = flipped_val.real
-           a = 1j * flipped_val.imag
+           h = flipped_val.real.copy()
+           a = 1j * flipped_val.imag.copy()
        else:
            flipped_val = flipped_val.conjugate()
            h = (val + flipped_val)/2.
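For context: the hermitian decomposition splits a complex array val into h + a with h(k) = conj(h(-k)). A minimal numpy sketch of the flip-and-conjugate construction, assuming the flip is a plain index negation modulo the axis length (the flip actually used by Field is supplied by the domain):

import numpy as np

def hermitian_split(val):
    # index -k modulo the axis length, so bin 0 maps onto itself
    idx = np.indices(val.shape)
    flipped_conj = np.conj(val[tuple((-i) % s
                                     for i, s in zip(idx, val.shape))])
    h = (val + flipped_conj) / 2.
    a = val - h
    return h, a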
@@ -600,7 +596,7 @@ class Field(Loggable, Versionable, object):
        # Now use numpy advanced indexing in order to put the entries of the
        # power spectrum into the appropriate places of the pindex array.
        # Do this for every 'pindex-slice' in parallel using the 'slice(None)'s
-       local_pindex = pindex.get_local_data(copy=False)
+       local_pindex = pindex

        local_blow_up = [slice(None)]*len(spec.shape)
        # it is important to count from behind, since spec potentially grows
...
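Both directions touched above can be mimicked with plain numpy on a flat toy field (no extra axes; the numbers are made up). Here rho plays the role of pdomain.rho, the number of pixels per power bin:

import numpy as np

pindex = np.array([[0, 1], [1, 2]])             # power-bin index per pixel
field_val = np.array([[4.0, 2.0], [2.0, 6.0]])  # weight per pixel
rho = np.bincount(pindex.ravel())               # pixels per bin

# analysis: weighted bincount, then divide by the bin multiplicity
power_spectrum = np.bincount(pindex.ravel(),
                             weights=field_val.ravel()) / rho

# synthesis: advanced indexing blows the spectrum back up onto the grid
blown_up = power_spectrum[pindex]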
@@ -20,7 +20,7 @@ from builtins import next
from builtins import range

import numpy as np

from itertools import product
+import itertools


def get_slice_list(shape, axes):
    """
@@ -110,3 +110,118 @@ def parse_domain(domain):
                "Given object contains something that is not an "
                "instance of DomainObject-class.")
    return domain


def slicing_generator(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over.

    Yields
    -------
    list
        The next list of indices and/or slice objects for each dimension.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.
    """
    if not shape:
        raise ValueError("ERROR: shape cannot be None.")

    if axes:
        if not all(axis < len(shape) for axis in axes):
            raise ValueError("ERROR: axes(axis) does not match shape.")
        axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
        axes_iterables =\
            [list(range(y)) for x, y in enumerate(shape) if x not in axes]
        for current_index in itertools.product(*axes_iterables):
            it_iter = iter(current_index)
            slice_list = [next(it_iter) if use_axis else
                          slice(None, None) for use_axis in axes_select]
            yield slice_list
    else:
        yield [slice(None, None)]
    return
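A quick usage sketch of the generator above: for a (2, 3, 4) array and axes=(2,), it yields one slice list per combination of the first two indices, each selecting a full stripe along the last axis:

import numpy as np

a = np.arange(24).reshape(2, 3, 4)
for slice_list in slicing_generator(a.shape, axes=(2,)):
    # e.g. [0, 0, slice(None, None)] -> a stripe of length 4
    stripe = a[tuple(slice_list)]
    assert stripe.shape == (4,)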


def bincount_axis(obj, minlength=None, weights=None, axis=None):
    if minlength is not None:
        length = max(np.amax(obj) + 1, minlength)
    else:
        length = np.amax(obj) + 1

    if obj.shape == ():
        raise ValueError("object of too small depth for desired array")
    data = obj

    # if present, parse the axis keyword and transpose/reorder self.data
    # such that all affected axes follow each other. Only if they are in a
    # sequence flattening will be possible
    if axis is not None:
        # do the reordering
        ndim = len(obj.shape)
        axis = sorted(cast_axis_to_tuple(axis, length=ndim))
        reordering = [x for x in range(ndim) if x not in axis]
        reordering += axis

        data = np.transpose(data, reordering)
        if weights is not None:
            weights = np.transpose(weights, reordering)

        reord_axis = list(range(ndim-len(axis), ndim))

        # semi-flatten the dimensions in `axis`, i.e. after reordering
        # the last ones.
        semi_flat_dim = reduce(lambda x, y: x*y,
                               data.shape[ndim-len(reord_axis):])
        flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
    else:
        flat_shape = (reduce(lambda x, y: x*y, data.shape), )

    data = np.ascontiguousarray(data.reshape(flat_shape))
    if weights is not None:
        weights = np.ascontiguousarray(
            weights.reshape(flat_shape))

    # compute the local bincount results
    # -> prepare the local result array
    if weights is None:
        result_dtype = np.int
    else:
        result_dtype = np.float
    local_counts = np.empty(flat_shape[:-1] + (length, ),
                            dtype=result_dtype)

    # iterate over all entries in the surviving axes and compute the local
    # bincounts
    for slice_list in slicing_generator(flat_shape,
                                        axes=(len(flat_shape)-1, )):
        if weights is not None:
            current_weights = weights[slice_list]
        else:
            current_weights = None
        local_counts[slice_list] = np.bincount(
                                        data[slice_list],
                                        weights=current_weights,
                                        minlength=length)

    # restore the original ordering
    # place the bincount stuff at the location of the first `axis` entry
    if axis is not None:
        # axis has been sorted above
        insert_position = axis[0]
        new_ndim = len(local_counts.shape)
        return_order = (list(range(0, insert_position)) +
                        [new_ndim-1, ] +
                        list(range(insert_position, new_ndim-1)))
        local_counts = np.ascontiguousarray(
                            local_counts.transpose(return_order))
    return local_counts
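A toy call of bincount_axis, binning along axis 1 so that every row gets its own histogram. The helper assumes that cast_axis_to_tuple and reduce are available elsewhere in this module; the numbers below are made up:

import numpy as np

pindex = np.array([[0, 1, 1, 2],
                   [2, 2, 0, 0]])
weights = np.ones(pindex.shape, dtype=np.float64)

counts = bincount_axis(pindex, weights=weights, axis=(1,))
# expected:
# [[1., 2., 1.],
#  [2., 0., 2.]]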
@@ -37,7 +37,7 @@ class DiagonalOperator(EndomorphicOperator):
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
-   diagonal : {scalar, list, array, Field, d2o-object}
+   diagonal : {scalar, list, array, Field}
        The diagonal entries of the operator.
    bare : boolean
        Indicates whether the input for the diagonal is bare or not
@@ -181,7 +181,7 @@ class DiagonalOperator(EndomorphicOperator):
        Parameters
        ----------
-       diagonal : {scalar, list, array, Field, d2o-object}
+       diagonal : {scalar, list, array, Field}
            The diagonal entries of the operator.
        bare : boolean
            Indicates whether the input for the diagonal is bare or not
@@ -226,16 +226,15 @@ class DiagonalOperator(EndomorphicOperator):
        for space_index in spaces:
            active_axes += x.domain_axes[space_index]

-       local_diagonal = self._diagonal.val.get_local_data(copy=False)
+       local_diagonal = self._diagonal.val

-       reshaper = [x.val.data.shape[i] if i in active_axes else 1
+       reshaper = [x.val.shape[i] if i in active_axes else 1
                    for i in range(len(x.shape))]
        reshaped_local_diagonal = np.reshape(local_diagonal, reshaper)

        # here the actual multiplication takes place
-       local_result = operation(reshaped_local_diagonal)(
-           x.val.get_local_data(copy=False))
+       local_result = operation(reshaped_local_diagonal)(x.val)

        result_field = x.copy_empty(dtype=local_result.dtype)
-       result_field.val.set_local_data(local_result, copy=False)
+       result_field.val=local_result

        return result_field
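The reshaping above lets the diagonal broadcast only along the active axes of the field values; a small numpy illustration with invented shapes:

import numpy as np

x_val = np.random.randn(4, 8, 8)   # field values, active space on axes (1, 2)
diagonal = np.random.randn(8, 8)   # diagonal entries living on that space
active_axes = (1, 2)

reshaper = [x_val.shape[i] if i in active_axes else 1
            for i in range(x_val.ndim)]           # -> [1, 8, 8]
result = x_val * diagonal.reshape(reshaper)       # broadcast over axis 0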
@@ -203,324 +203,6 @@ class Transform(Loggable, object):
        raise NotImplementedError


class MPIFFT(Transform):
    """
        The MPI-parallel FFTW pendant of a fft object.
    """

    def __init__(self, domain, codomain):
        if not hasattr(fftw, 'FFTW_MPI'):
            raise ImportError(
                "The MPI FFTW module is needed but not available.")

        super(MPIFFT, self).__init__(domain, codomain)

        # Enable caching
        fftw.interfaces.cache.enable()

        # The plan_dict stores the FFTWTransformInfo objects which correspond
        # to a certain set of (field_val, domain, codomain) sets.
        self.info_dict = {}

    def _get_transform_info(self, domain, codomain, axes, local_shape,
                            local_offset_Q, is_local, transform_shape=None,
                            **kwargs):
        # generate a id-tuple which identifies the domain-codomain setting
        temp_id = (domain, codomain, transform_shape, is_local)

        # generate the plan_and_info object if not already there
        if temp_id not in self.info_dict:
            if is_local:
                self.info_dict[temp_id] = FFTWLocalTransformInfo(
                    domain, codomain, axes, local_shape,
                    local_offset_Q, self, **kwargs
                )
            else:
                self.info_dict[temp_id] = FFTWMPITransfromInfo(
                    domain, codomain, axes, local_shape,
                    local_offset_Q, self, transform_shape, **kwargs
                )

        return self.info_dict[temp_id]

    def _atomic_mpi_transform(self, val, info, axes):
        # Apply codomain centering mask
        if reduce(lambda x, y: x + y, self.codomain.zerocenter):
            temp_val = np.copy(val)
            val = self._apply_mask(temp_val, info.cmask_codomain, axes)

        p = info.plan
        # Load the value into the plan
        if p.has_input:
            try:
                p.input_array[None] = val
            except ValueError:
                raise ValueError("Failed to load data into input_array of "
                                 "FFTW MPI-plan. Maybe the 1D slicing differs"
                                 "from n-D slicing?")
        # Execute the plan
        p()

        if p.has_output:
            result = p.output_array.copy()
            if result.shape != val.shape:
                raise ValueError("Output shape is different than input shape. "
                                 "Maybe fftw tries to optimize the "
                                 "bit-alignment? Try a different array-size.")
        else:
            return None

        # Apply domain centering mask
        if reduce(lambda x, y: x + y, self.domain.zerocenter):
            result = self._apply_mask(result, info.cmask_domain, axes)

        # Correct the sign if needed
        result *= info.sign

        return result

    def _local_transform(self, val, axes, **kwargs):
        ####
        # val must be numpy array or d2o with slicing distributor
        ###
        try:
            local_val = val.get_local_data(copy=False)
        except(AttributeError):
            local_val = val

        current_info = self._get_transform_info(self.domain,
                                                self.codomain,
                                                axes,
                                                local_shape=local_val.shape,
                                                local_offset_Q=False,
                                                is_local=True,
                                                **kwargs)

        # Apply codomain centering mask
        if reduce(lambda x, y: x + y, self.codomain.zerocenter):
            temp_val = np.copy(local_val)
            local_val = self._apply_mask(temp_val,
                                         current_info.cmask_codomain, axes)

        local_result = current_info.fftw_interface(
            local_val,
            axes=axes,
            planner_effort='FFTW_ESTIMATE'
        )

        # Apply domain centering mask
        if reduce(lambda x, y: x + y, self.domain.zerocenter):
            local_result = self._apply_mask(local_result,
                                            current_info.cmask_domain, axes)

        # Correct the sign if needed
        if current_info.sign != 1:
            local_result *= current_info.sign

        try:
            # Create return object and insert results inplace
            return_val = val.copy_empty(global_shape=val.shape,
                                        dtype=np.complex)
            return_val.set_local_data(data=local_result, copy=False)
        except(AttributeError):
            return_val = local_result

        return return_val

    def _repack_to_fftw_and_transform(self, val, axes, **kwargs):
        temp_val = val.copy_empty()
        temp_val.set_full_data(val, copy=False)

        # Recursive call to transform
        result = self.transform(temp_val, axes, **kwargs)

        return_val = result.copy_empty()
        return_val.set_full_data(data=result, copy=False)

        return return_val

    def _mpi_transform(self, val, axes, **kwargs):
        local_offset_list = np.cumsum(
            np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
        )
        local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)

        return_val = val.copy_empty(global_shape=val.shape,
                                    dtype=np.complex)

        # Extract local data
        local_val = val.get_local_data(copy=False)

        # Create temporary storage for slices
        temp_val = None

        # If axes tuple includes all axes, set it to None
        if axes is not None:
            if set(axes) == set(range(len(val.shape))):
                axes = None

        current_info = None
        for slice_list in utilities.get_slice_list(local_val.shape, axes):
            if slice_list == [slice(None, None)]:
                inp = local_val
            else:
                if temp_val is None:
                    temp_val = np.empty_like(
                        local_val,
                        dtype=np.complex
                    )
                inp = local_val[slice_list]

            if current_info is None:
                transform_shape = list(inp.shape)
                transform_shape[0] = val.shape[0]

                current_info = self._get_transform_info(
                    self.domain,
                    self.codomain,
                    axes,
                    local_shape=val.local_shape,
                    local_offset_Q=local_offset_Q,
                    is_local=False,
                    transform_shape=tuple(transform_shape),
                    **kwargs
                )

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                result = self._atomic_mpi_transform(inp, current_info, axes)

            if result is None:
                temp_val = np.empty_like(local_val)
            elif slice_list == [slice(None, None)]:
                temp_val = result
            else:
                temp_val[slice_list] = result

        return_val.set_local_data(data=temp_val, copy=False)

        return return_val

    def transform(self, val, axes, **kwargs):
        """
            The MPI-parallel FFTW transform function.

            Parameters
            ----------
            val : distributed_data_object or numpy.ndarray
                The value-array of the field which is supposed to
                be transformed.
            axes: tuple, None
                The axes which should be transformed.
            **kwargs : *optional*
                Further kwargs are passed to the create_mpi_plan routine.

            Returns
            -------
            result : np.ndarray or distributed_data_object
                Fourier-transformed pendant of the input field.
        """
        # Check if the axes provided are valid given the shape
        if axes is not None and \
                not all(axis in range(len(val.shape)) for axis in axes):
            raise ValueError("Provided axes does not match array shape")

        # If the input is a numpy array we transform it locally
        temp_val = np.asarray(val)

        # local transform doesn't apply transforms inplace
        return_val = self._local_transform(temp_val, axes)

        return return_val


class FFTWTransformInfo(object):
    def __init__(self, domain, codomain, axes, local_shape,
                 local_offset_Q, fftw_context, **kwargs):
        if not hasattr(fftw, 'FFTW_MPI'):
            raise ImportError(
                "The MPI FFTW module is needed but not available.")

        shape = (local_shape if axes is None else
                 [y for x, y in enumerate(local_shape) if x in axes])

        self._cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
                                                             shape,
                                                             local_offset_Q)