Commit 803b053a authored by Ultima's avatar Ultima
Browse files

nifty_mpi_data.py is now PEP8 compliant.

parent d7045735
# -*- coding: utf-8 -*-
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Theo Steininger
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.
# NIFTY (Numerical Information Field Theory) has been developed at the
# Max-Planck-Institute for Astrophysics.
#
# Copyright (C) 2015 Max-Planck-Society
#
# Author: Theo Steininger
# Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
......@@ -101,6 +101,7 @@ class distributed_data_object(object):
If the supplied distribution strategy is not known.
"""
def __init__(self, global_data=None, global_shape=None, dtype=None,
local_data=None, local_shape=None,
distribution_strategy='fftw', hermitian=False,
......@@ -115,15 +116,15 @@ class distributed_data_object(object):
local_data = np.array(local_data, copy=False)
self.distributor = distributor_factory.get_distributor(
distribution_strategy = distribution_strategy,
comm = comm,
global_data = global_data,
global_shape = global_shape,
local_data = local_data,
local_shape = local_shape,
alias = alias,
path = path,
dtype = dtype,
distribution_strategy=distribution_strategy,
comm=comm,
global_data=global_data,
global_shape=global_shape,
local_data=local_data,
local_shape=local_shape,
alias=alias,
path=path,
dtype=dtype,
**kwargs)
self.distribution_strategy = distribution_strategy
......@@ -136,12 +137,12 @@ class distributed_data_object(object):
self.init_kwargs = kwargs
(self.data, self.hermitian) = self.distributor.initialize_data(
global_data = global_data,
local_data = local_data,
alias = alias,
path = path,
hermitian = hermitian,
copy = copy)
global_data=global_data,
local_data=local_data,
alias=alias,
path=path,
hermitian=hermitian,
copy=copy)
self.index = d2o_librarian.register(self)
def copy(self, dtype=None, distribution_strategy=None, **kwargs):
......@@ -159,14 +160,14 @@ class distributed_data_object(object):
def copy_empty(self, global_shape=None, local_shape=None, dtype=None,
distribution_strategy=None, **kwargs):
if self.distribution_strategy == 'not' and \
distribution_strategy in LOCAL_DISTRIBUTION_STRATEGIES and \
local_shape == None:
result = self.copy_empty(global_shape = global_shape,
local_shape = local_shape,
dtype = dtype,
distribution_strategy = 'equal',
distribution_strategy in LOCAL_DISTRIBUTION_STRATEGIES and \
local_shape is None:
result = self.copy_empty(global_shape=global_shape,
local_shape=local_shape,
dtype=dtype,
distribution_strategy='equal',
**kwargs)
return result.copy_empty(distribution_strategy = 'freeform')
return result.copy_empty(distribution_strategy='freeform')
if global_shape is None:
global_shape = self.shape
......@@ -179,19 +180,20 @@ class distributed_data_object(object):
kwargs.update(self.init_kwargs)
temp_d2o = distributed_data_object(global_shape=global_shape,
local_shape = local_shape,
dtype = dtype,
distribution_strategy = distribution_strategy,
comm = self.comm,
*self.init_args,
**kwargs)
temp_d2o = distributed_data_object(
global_shape=global_shape,
local_shape=local_shape,
dtype=dtype,
distribution_strategy=distribution_strategy,
comm=self.comm,
*self.init_args,
**kwargs)
return temp_d2o
def apply_scalar_function(self, function, inplace=False, dtype=None):
remember_hermitianQ = self.hermitian
if inplace == True:
if inplace is True:
temp = self
if dtype is not None and self.dtype != np.dtype(dtype):
about.warnings.cprint(
......@@ -224,22 +226,21 @@ class distributed_data_object(object):
return self.data.__str__()
def __repr__(self):
return '<distributed_data_object>\n'+self.data.__repr__()
return '<distributed_data_object>\n' + self.data.__repr__()
def _compare_helper(self, other, op):
result = self.copy_empty(dtype = np.bool_)
result = self.copy_empty(dtype=np.bool_)
# Case 1: 'other' is a scalar
# -> make point-wise comparison
if np.isscalar(other):
result.set_local_data(
getattr(self.get_local_data(copy = False), op)(other))
getattr(self.get_local_data(copy=False), op)(other))
return result
# Case 2: 'other' is a numpy array or a distributed_data_object
# -> extract the local data and make point-wise comparison
elif isinstance(other, np.ndarray) or\
isinstance(other, distributed_data_object):
isinstance(other, distributed_data_object):
temp_data = self.distributor.extract_local_data(other)
result.set_local_data(
getattr(self.get_local_data(copy=False), op)(temp_data))
......@@ -255,7 +256,6 @@ class distributed_data_object(object):
temp_other = np.array(other)
return getattr(self, op)(temp_other)
def __ne__(self, other):
return self._compare_helper(other, '__ne__')
......@@ -268,6 +268,7 @@ class distributed_data_object(object):
def __eq__(self, other):
return self._compare_helper(other, '__eq__')
def __ge__(self, other):
return self._compare_helper(other, '__ge__')
......@@ -289,18 +290,15 @@ class distributed_data_object(object):
else:
return True
def __pos__(self):
temp_d2o = self.copy_empty()
temp_d2o.set_local_data(data = self.get_local_data(), copy = True)
temp_d2o.set_local_data(data=self.get_local_data(), copy=True)
return temp_d2o
def __neg__(self):
temp_d2o = self.copy_empty()
temp_d2o.set_local_data(data = self.get_local_data().__neg__(),
copy = True)
temp_d2o.set_local_data(data=self.get_local_data().__neg__(),
copy=True)
return temp_d2o
def __abs__(self):
......@@ -313,9 +311,9 @@ class distributed_data_object(object):
new_dtype = np.dtype('float')
else:
new_dtype = self.dtype
temp_d2o = self.copy_empty(dtype = new_dtype)
temp_d2o.set_local_data(data = self.get_local_data().__abs__(),
copy = True)
temp_d2o = self.copy_empty(dtype=new_dtype)
temp_d2o.set_local_data(data=self.get_local_data().__abs__(),
copy=True)
return temp_d2o
def _builtin_helper(self, operator, other, inplace=False):
......@@ -344,14 +342,14 @@ class distributed_data_object(object):
hermitian_Q = False
temp_data = operator(other)
# write the new data into a new distributed_data_object
if inplace == True:
if inplace is True:
temp_d2o = self
else:
# use common datatype for self and other
new_dtype = np.dtype(np.find_common_type((self.dtype,),
(temp_data.dtype,)))
temp_d2o = self.copy_empty(
dtype = new_dtype)
dtype=new_dtype)
temp_d2o.set_local_data(data=temp_data)
temp_d2o.hermitian = hermitian_Q
return temp_d2o
......@@ -364,8 +362,8 @@ class distributed_data_object(object):
def __iadd__(self, other):
return self._builtin_helper(self.get_local_data().__iadd__,
other,
inplace = True)
other,
inplace=True)
def __sub__(self, other):
return self._builtin_helper(self.get_local_data().__sub__, other)
......@@ -375,8 +373,8 @@ class distributed_data_object(object):
def __isub__(self, other):
return self._builtin_helper(self.get_local_data().__isub__,
other,
inplace = True)
other,
inplace=True)
def __div__(self, other):
return self._builtin_helper(self.get_local_data().__div__, other)
......@@ -392,21 +390,24 @@ class distributed_data_object(object):
def __idiv__(self, other):
return self._builtin_helper(self.get_local_data().__idiv__,
other,
inplace = True)
other,
inplace=True)
def __itruediv__(self, other):
return self.__idiv__(other)
def __floordiv__(self, other):
return self._builtin_helper(self.get_local_data().__floordiv__,
other)
other)
def __rfloordiv__(self, other):
return self._builtin_helper(self.get_local_data().__rfloordiv__,
other)
other)
def __ifloordiv__(self, other):
return self._builtin_helper(
self.get_local_data().__ifloordiv__, other,
inplace = True)
self.get_local_data().__ifloordiv__, other,
inplace=True)
def __mul__(self, other):
return self._builtin_helper(self.get_local_data().__mul__, other)
......@@ -416,8 +417,8 @@ class distributed_data_object(object):
def __imul__(self, other):
return self._builtin_helper(self.get_local_data().__imul__,
other,
inplace = True)
other,
inplace=True)
def __pow__(self, other):
return self._builtin_helper(self.get_local_data().__pow__, other)
......@@ -427,16 +428,20 @@ class distributed_data_object(object):
def __ipow__(self, other):
return self._builtin_helper(self.get_local_data().__ipow__,
other,
inplace = True)
other,
inplace=True)
def __mod__(self, other):
return self._builtin_helper(self.get_local_data().__mod__, other)
def __rmod__(self, other):
return self._builtin_helper(self.get_local_data().__rmod__, other)
def __imod__(self, other):
return self._builtin_helper(self.get_local_data().__imod__,
other,
inplace = True)
other,
inplace=True)
def __len__(self):
return self.shape[0]
......@@ -450,12 +455,9 @@ class distributed_data_object(object):
global_vdot = np.sum(local_vdot_list)
return global_vdot
def __getitem__(self, key):
return self.get_data(key)
def __setitem__(self, key, data):
self.set_data(data, key)
......@@ -468,11 +470,11 @@ class distributed_data_object(object):
include = True
local_list = self.distributor._allgather(local)
local_list = np.array(local_list, dtype = np.dtype(local_list[0]))
local_list = np.array(local_list, dtype=np.dtype(local_list[0]))
include_list = np.array(self.distributor._allgather(include))
work_list = local_list[include_list]
if work_list.shape[0] == 0:
raise ValueError("ERROR: Zero-size array to reduction operation "+
raise ValueError("ERROR: Zero-size array to reduction operation " +
"which has no identity")
else:
result = function(work_list, axis=0)
......@@ -510,8 +512,8 @@ class distributed_data_object(object):
local_mean_list = self.distributor._allgather(local_mean)
local_weight_list = self.distributor._allgather(local_weight)
local_mean_list =np.array(local_mean_list,
dtype = np.dtype(local_mean_list[0]))
local_mean_list = np.array(local_mean_list,
dtype=np.dtype(local_mean_list[0]))
local_weight_list = np.array(local_weight_list)
# extract the parts from the non-empty nodes
include_list = np.array(self.distributor._allgather(include))
......@@ -524,7 +526,7 @@ class distributed_data_object(object):
global_weight = np.sum(work_weight_list)
# compute the numerator
numerator = np.sum(work_mean_list * work_weight_list)
global_mean = numerator/global_weight
global_mean = numerator / global_weight
return global_mean
def var(self):
......@@ -546,12 +548,13 @@ class distributed_data_object(object):
self.data.shape)]
globalized_local_argmin = self.distributor.globalize_flat_index(
local_argmin)
local_argmin_list = self.distributor._allgather((local_argmin_value,
globalized_local_argmin))
local_argmin)
local_argmin_list = self.distributor._allgather(
(local_argmin_value,
globalized_local_argmin))
local_argmin_list = np.array(local_argmin_list, dtype=[
('value', np.dtype('complex128')),
('index', np.dtype('float'))])
('value', np.dtype('complex128')),
('index', np.dtype('float'))])
local_argmin_list = np.sort(local_argmin_list,
order=['value', 'index'])
return np.int(local_argmin_list[0][1])
......@@ -564,19 +567,19 @@ class distributed_data_object(object):
else:
local_argmax = np.argmax(self.data)
local_argmax_value = -self.data[np.unravel_index(local_argmax,
self.data.shape)]
self.data.shape)]
globalized_local_argmax = self.distributor.globalize_flat_index(
local_argmax)
local_argmax_list = self.distributor._allgather((local_argmax_value,
globalized_local_argmax))
local_argmax)
local_argmax_list = self.distributor._allgather(
(local_argmax_value,
globalized_local_argmax))
local_argmax_list = np.array(local_argmax_list, dtype=[
('value', np.dtype('complex128')),
('index', np.dtype('float'))])
('value', np.dtype('complex128')),
('index', np.dtype('float'))])
local_argmax_list = np.sort(local_argmax_list,
order=['value', 'index'])
return np.int(local_argmax_list[0][1])
def argmin_nonflat(self):
return np.unravel_index(self.argmin(), self.shape)
......@@ -589,12 +592,11 @@ class distributed_data_object(object):
temp_d2o.set_local_data(temp_data)
return temp_d2o
def conj(self):
return self.conjugate()
def median(self):
about.warnings.cprint(\
about.warnings.cprint(
"WARNING: The current implementation of median is very expensive!")
median = np.median(self.get_full_data())
return median
......@@ -609,7 +611,6 @@ class distributed_data_object(object):
temp_d2o.set_local_data(np.isreal(self.data))
return temp_d2o
def all(self):
local_all = np.all(self.get_local_data())
global_all = self.distributor._allgather(local_all)
......@@ -626,32 +627,31 @@ class distributed_data_object(object):
global_unique = np.concatenate(global_unique)
return np.unique(global_unique)
def bincount(self, weights = None, minlength = None):
def bincount(self, weights=None, minlength=None):
if self.dtype not in [np.dtype('int16'), np.dtype('int32'),
np.dtype('int64'), np.dtype('uint16'),
np.dtype('uint32'), np.dtype('uint64')]:
np.dtype('int64'), np.dtype('uint16'),
np.dtype('uint32'), np.dtype('uint64')]:
raise TypeError(about._errors.cstring(
"ERROR: Distributed-data-object must be of integer datatype!"))
minlength = max(self.amax()+1, minlength)
minlength = max(self.amax() + 1, minlength)
if weights is not None:
local_weights = self.distributor.extract_local_data(weights).\
flatten()
flatten()
else:
local_weights = None
local_counts = np.bincount(self.get_local_data().flatten(),
weights = local_weights,
minlength = minlength)
weights=local_weights,
minlength=minlength)
if self.distribution_strategy == 'not':
return local_counts
else:
list_of_counts = self.distributor._allgather(local_counts)
counts = np.sum(list_of_counts, axis = 0)
counts = np.sum(list_of_counts, axis=0)
return counts
def where(self):
return self.distributor.where(self.data)
......@@ -672,11 +672,13 @@ class distributed_data_object(object):
"""
self.hermitian = hermitian
if copy == True:
if copy is True:
self.data[:] = data
else:
self.data = np.array(data, dtype=self.dtype,
copy=False, order='C').reshape(self.local_shape)
self.data = np.array(data,
dtype=self.dtype,
copy=False,
order='C').reshape(self.local_shape)
def set_data(self, data, to_key, from_key=None, local_keys=False,
hermitian=False, copy=True, **kwargs):
......@@ -700,15 +702,15 @@ class distributed_data_object(object):
"""
self.hermitian = hermitian
self.distributor.disperse_data(data = self.data,
to_key = to_key,
data_update = data,
from_key = from_key,
local_keys = local_keys,
copy = copy,
self.distributor.disperse_data(data=self.data,
to_key=to_key,
data_update=data,
from_key=from_key,
local_keys=local_keys,
copy=copy,
**kwargs)
def set_full_data(self, data, hermitian=False, copy = True, **kwargs):
def set_full_data(self, data, hermitian=False, copy=True, **kwargs):
"""
Distributes the supplied data to the nodes. The shape of data must
match the shape of the distributed_data_object.
......@@ -729,7 +731,7 @@ class distributed_data_object(object):
"""
self.hermitian = hermitian
self.data = self.distributor.distribute_data(data=data, copy = copy,
self.data = self.distributor.distribute_data(data=data, copy=copy,
**kwargs)
def get_local_data(self, key=(slice(None),), copy=True):
......@@ -747,9 +749,9 @@ class distributed_data_object(object):
self.data[key] : numpy.ndarray
"""
if copy == True:
if copy is True:
return self.data[key]
if copy == False:
if copy is False:
return self.data
def get_data(self, key, local_keys=False, **kwargs):
......@@ -785,10 +787,9 @@ class distributed_data_object(object):
return self.distributor.collect_data(self.data,
key,
local_keys = local_keys,
local_keys=local_keys,
**kwargs)
def get_full_data(self, target_rank='all'):
"""
Fully consolidates the distributed data.
......@@ -810,7 +811,7 @@ class distributed_data_object(object):
"""
return self.distributor.consolidate_data(self.data,
target_rank = target_rank)
target_rank=target_rank)
def inject(self, to_key=(slice(None),), data=None,
from_key=(slice(None),)):
......@@ -818,38 +819,35 @@ class distributed_data_object(object):
return self
self.distributor.inject(self.data, to_key, data, from_key)
def flatten(self, inplace = False):
flat_data = self.distributor.flatten(self.data, inplace = inplace)
def flatten(self, inplace=False):
flat_data = self.distributor.flatten(self.data, inplace=inplace)
flat_global_shape = (np.prod(self.shape),)
flat_local_shape = np.shape(flat_data)
# Try to keep the distribution strategy. Therefore
# create an empty copy of self which has the new shape
temp_d2o = self.copy_empty(global_shape = flat_global_shape,
local_shape = flat_local_shape)
temp_d2o = self.copy_empty(global_shape=flat_global_shape,
local_shape=flat_local_shape)
# Check if the local shapes match.
if temp_d2o.local_shape == flat_local_shape:
work_d2o = temp_d2o
# if the shapes do not match, create a freeform d2o
else:
work_d2o = self.copy_empty(local_shape = flat_local_shape,
distribution_strategy = 'freeform')
work_d2o = self.copy_empty(local_shape=flat_local_shape,
distribution_strategy='freeform')
# Feed the work_d2o with the flat data
work_d2o.set_local_data(data = flat_data,
copy = False)
work_d2o.set_local_data(data=flat_data,
copy=False)
if inplace == True:
if inplace is True:
self = work_d2o
return self
else:
return work_d2o
def save(self, alias, path=None, overwriteQ=True):
"""
Saves a distributed_data_object to disk utilizing h5py.
......@@ -884,17 +882,16 @@ class distributed_data_object(object):
self.data = self.distributor.load_data(alias, path)
class _distributor_factory(object):
def __init__(self):
self.distributor_store = {}
def parse_kwargs(self, distribution_strategy, comm,
global_data = None, global_shape = None,
local_data = None, local_shape = None,
alias = None, path = None,
dtype = None, **kwargs):
global_data=None, global_shape=None,
local_data=None, local_shape=None,