Commit 3e0aa18d authored by Ultimanet

rg_space now fully d2o capable

parent 9c1ab786
......@@ -20,6 +20,10 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
import matplotlib as mpl
mpl.use('Agg')
from nifty_about import about
from nifty_cmaps import ncmap
from nifty_core import space,\
......
......@@ -36,7 +36,7 @@ from nifty import * # version
# some signal space; e.g., a one-dimensional regular grid
x_space = rg_space(128) # define signal space
x_space = rg_space([128,]) # define signal space
k_space = x_space.get_codomain() # get conjugate space
......
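The list form of the shape argument generalizes directly to higher dimensions; a hedged sketch of a 2-d grid (not part of the demo itself):

x_space_2d = rg_space([128, 128])       # two-dimensional regular grid
k_space_2d = x_space_2d.get_codomain()  # conjugate (harmonic) space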
......@@ -87,7 +87,8 @@ class _COMM_WORLD():
class _datatype():
def __init__(self, name):
self.name = str(name)
BYTE = _datatype('MPI_BYTE')
SHORT = _datatype('MPI_SHORT')
UNSIGNED_SHORT = _datatype("MPI_UNSIGNED_SHORT")
UNSIGNED_INT = _datatype("MPI_UNSIGNED_INT")
......
......@@ -1335,7 +1335,10 @@ class point_space(space):
"""
self.set_power_indices(**kwargs)
return self.power_indices.get("kindex"),self.power_indices.get("rho"),self.power_indices.get("pindex"),self.power_indices.get("pundex")
return self.power_indices.get("kindex"),\
self.power_indices.get("rho"),\
self.power_indices.get("pindex"),\
self.power_indices.get("pundex")
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
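For orientation, the reformatted return statement yields the same four index objects as before; a hypothetical unpacking (keywords are forwarded to set_power_indices):

## kindex: length scale per bin, rho: number of modes per bin,
## pindex: maps grid points to bins, pundex: inverse lookup indices
kindex, rho, pindex, pundex = x_space.get_power_indices()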
......@@ -1472,7 +1475,7 @@ class point_space(space):
vmax : float, *optional*
Upper limit for a uniform distribution (default: 1).
"""
arg = random.arguments(self,**kwargs)
arg = random.parse_arguments(self,**kwargs)
if(arg is None):
x = np.zeros(self.dim(split=True),dtype=self.datatype,order='C')
......@@ -2111,7 +2114,7 @@ class nested_space(space):
vmax : float, *optional*
Upper limit for a uniform distribution (default: 1).
"""
arg = random.arguments(self,**kwargs)
arg = random.parse_arguments(self,**kwargs)
if(arg is None):
return np.zeros(self.dim(split=True),dtype=self.datatype,order='C')
......@@ -3263,14 +3266,22 @@ class field(object):
corresponding :py:meth:`get_plot` method.
"""
interactive = pl.isinteractive()
pl.matplotlib.interactive(not bool(kwargs.get("save",False)))
## if a save path is given, set pylab to non-interactive
remember_interactive = pl.isinteractive()
pl.matplotlib.interactive(not bool(
kwargs.get("save", False)
)
)
if("codomain" in kwargs):
if "codomain" in kwargs:
kwargs.__delitem__("codomain")
self.domain.get_plot(self.val,codomain=self.target,**kwargs)
about.warnings.cprint("WARNING: codomain was removed from kwargs.")
pl.matplotlib.interactive(interactive)
## draw/save the plot(s)
self.domain.get_plot(self.val, codomain=self.target, **kwargs)
## restore the pylab interactiveness
pl.matplotlib.interactive(remember_interactive)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
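The intent of the reworked plot method: a given save path switches pylab to non-interactive mode so the figure goes to disk, and the previous interactiveness is restored afterwards. A minimal usage sketch, assuming a field f as in the demos:

f = field(x_space, random='gau')   # some Gaussian random field
f.plot()                           # interactive window
f.plot(save='signal.png')          # non-interactive, written to disk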
......@@ -100,7 +100,7 @@ class distributed_data_object(object):
distribution_strategy='fftw', hermitian=False,
*args, **kwargs):
if global_data != None:
if np.array(global_data).shape == ():
if np.isscalar(global_data):
global_data_input = None
dtype = np.array(global_data).dtype.type
else:
......@@ -114,6 +114,7 @@ class distributed_data_object(object):
global_data=global_data_input,
global_shape=global_shape,
dtype=dtype, **kwargs)
self.set_full_data(data=global_data_input, hermitian=hermitian,
**kwargs)
......@@ -126,7 +127,7 @@ class distributed_data_object(object):
self.init_kwargs = kwargs
## If the input data was a scalar, set the whole array to this value
if global_data != None and np.array(global_data).shape == ():
if global_data != None and np.isscalar(global_data):
temp = np.empty(self.distributor.local_shape)
temp.fill(global_data)
self.set_local_data(temp)
......@@ -187,6 +188,35 @@ class distributed_data_object(object):
return '<distributed_data_object>\n'+self.data.__repr__()
def __eq__(self, other):
result = self.copy_empty(dtype = np.bool)
## Case 1: 'other' is a scalar
## -> make point-wise comparison
if np.isscalar(other):
result.set_local_data(self.get_local_data(copy = False) == other)
return result
## Case 2: 'other' is a numpy array or a distributed_data_object
## -> extract the local data and make point-wise comparison
elif isinstance(other, np.ndarray) or\
isinstance(other, distributed_data_object):
temp_data = self.distributor.extract_local_data(other)
result.set_local_data(self.get_local_data(copy=False) == temp_data)
return result
## Case 3: 'other' is None
elif other == None:
return False
## Case 4: 'other' is something different
## -> cast it to a numpy array and recurse
else:
temp_other = np.array(other)
return self.__eq__(temp_other)
def equal(self, other):
if other is None:
return False
try:
......@@ -196,7 +226,7 @@ class distributed_data_object(object):
assert(self.init_kwargs == other.init_kwargs)
assert(self.distribution_strategy == other.distribution_strategy)
assert(np.all(self.data == other.data))
except(AssertionError):
except(AssertionError, AttributeError):
return False
else:
return True
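Note the two distinct semantics introduced here: __eq__ compares point-wise and returns a boolean d2o, while equal compares the full object state and returns a single bool. An illustrative sketch (a running MPI context and the d2o copy method are assumed):

a = distributed_data_object(global_data=np.arange(4))
b = a.copy()
mask = (a == b)      # boolean distributed_data_object, True everywhere
same = a.equal(b)    # single Python bool: True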
......@@ -215,27 +245,44 @@ class distributed_data_object(object):
return temp_d2o
def __abs__(self):
temp_d2o = self.copy_empty()
## translate complex dtypes
if self.dtype == np.complex64:
new_dtype = np.float32
elif self.dtype == np.complex128:
new_dtype = np.float64
elif self.dtype == np.complex:
new_dtype = np.float
elif issubclass(self.dtype, np.complexfloating):
new_dtype = np.float
else:
new_dtype = self.dtype
temp_d2o = self.copy_empty(dtype = new_dtype)
temp_d2o.set_local_data(data = self.get_local_data().__abs__())
return temp_d2o
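The dtype translation ensures the absolute value of a complex d2o lands in the matching real dtype. A hedged example of the expected behaviour:

c = distributed_data_object(global_data=np.array([3.+4.j]))
r = abs(c)                     # new d2o, dtype complex128 -> float64
## r.get_full_data() should yield array([ 5.])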
def __builtin_helper__(self, operator, other):
temp_d2o = self.copy_empty()
if not np.isscalar(other):
new_other = self.copy_empty()
new_other.set_full_data(np.array(other))
other = new_other
## Case 1: other is not a scalar
if not (np.isscalar(other) or np.shape(other) == (1,)):
## if self.shape != other.shape:
## raise AttributeError(about._errors.cstring(
## "ERROR: Shapes do not match!"))
## extract the local data from the 'other' object
temp_data = self.distributor.extract_local_data(other)
temp_data = operator(temp_data)
if isinstance(other, distributed_data_object):
temp_data = operator(other.get_local_data())
else:
temp_data = operator(other)
## write the new data into a new distributed_data_object
temp_d2o = self.copy_empty()
temp_d2o.set_local_data(data=temp_data)
return temp_d2o
def __inplace_builtin_helper__(self, operator, other):
if isinstance(other, distributed_data_object):
temp_data = operator(other.get_local_data())
if not (np.isscalar(other) or np.shape(other) == (1,)):
temp_data = self.distributor.extract_local_data(other)
temp_data = operator(temp_data)
else:
temp_data = operator(other)
self.set_local_data(data=temp_data)
......@@ -319,7 +366,28 @@ class distributed_data_object(object):
def __getitem__(self, key):
return self.get_data(key)
## Case 1: key is a boolean array.
## -> take the local data portion from key, use this for data
## extraction, and then merge the results into a flat numpy array
if isinstance(key, np.ndarray):
found = 'ndarray'
found_boolean = (key.dtype.type == np.bool)
elif isinstance(key, distributed_data_object):
found = 'd2o'
found_boolean = (key.dtype == np.bool)
else:
found = 'other'
if (found == 'ndarray' or found == 'd2o') and found_boolean == True:
## extract the data of local relevance
local_bool_array = self.distributor.extract_local_data(key)
local_results = self.get_local_data(copy=False)[local_bool_array]
global_results = self.distributor._allgather(local_results)
global_results = np.concatenate(global_results)
return global_results
else:
return self.get_data(key)
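With the new branch, indexing a d2o with a boolean array gathers the locally selected entries from all nodes into one flat numpy array. A short sketch of the supported call:

d = distributed_data_object(global_data=np.arange(6))
mask = np.arange(6) > 3       # plain boolean ndarray used as key
print(d[mask])                # consolidated flat ndarray: [4 5]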
def __setitem__(self, key, data):
self.set_data(data, key)
......@@ -508,7 +576,7 @@ class distributed_data_object(object):
self.data = self.distributor.distribute_data(data=data, **kwargs)
def get_local_data(self, key=(slice(None),)):
def get_local_data(self, key=(slice(None),), copy=True):
"""
Loads data directly from the local data attribute. No consolidation
is done.
......@@ -523,7 +591,10 @@ class distributed_data_object(object):
self.data[key] : numpy.ndarray
"""
return self.data[key]
if copy == True:
return self.data[key]
if copy == False:
return self.data
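Note that on the copy=False path the key is ignored and a reference to the full local array is returned, trading safety for speed; a brief sketch:

local_view = d.get_local_data(copy=False)  # reference to self.data, key ignored
local_copy = d.get_local_data()            # default key (slice(None),) yields a copy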
def get_data(self, key, **kwargs):
"""
......@@ -697,23 +768,25 @@ class _fftw_distributor(object):
if comm.rank == 0:
## Case 1: hdf5 path supplied
if alias != None:
self.global_shape = dset.shape
else:
if global_data == None or np.array(global_data).shape == ():
## Case 2: no hdf5 path supplied
else:
## subcase 1: input data is scalar or None
if global_data == None or np.isscalar(global_data):
if global_shape == None:
raise TypeError(about._errors.\
cstring("ERROR: Neither data nor shape supplied!"))
cstring("ERROR: Neither non-scalar data nor shape supplied!"))
else:
self.global_shape = global_shape
## subcase 2: input data is non-scalar
## -> Take the shape of the input data
else:
self.global_shape = global_data.shape
else:
self.global_shape = None
self.global_shape = comm.bcast(self.global_shape, root = 0)
self.global_shape = tuple(self.global_shape)
......@@ -889,18 +962,33 @@ class _fftw_distributor(object):
if from_slices == None:
update_slice = (slice(o[r], o[r]+l),)
else:
f_start = from_slices[0].start
f_step = from_slices[0].step
if f_step == None:
f_step = 1
f_direction = np.sign(f_step)
f_relative_start = from_slices[0].start
if f_relative_start != None:
f_start = f_relative_start + f_direction*o[r]
else:
f_start = None
f_relative_start = 0
f_stop = f_relative_start + f_direction*(o[r]+l*np.abs(f_step))
if f_stop < 0:
f_stop = None
## combine the slicing for the first dimension
update_slice = (slice(f_start + f_direction*o[r],
f_start + f_direction*(o[r]+l),
update_slice = (slice(f_start,
f_stop,
f_step),
)
## add the rest of the from_slices
update_slice += from_slices[1:]
data[local_slice] = np.array(data_update[update_slice],\
copy=False).astype(self.dtype)
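A worked example of the slice arithmetic above, with hypothetical numbers:

## suppose this node owns l = 2 rows starting at global offset o[r] = 4,
## and from_slices = (slice(1, None, 2),):
##   f_step = 2, f_direction = +1, f_relative_start = 1
##   f_start = 1 + 1*4         = 5
##   f_stop  = 1 + 1*(4 + 2*2) = 9
## -> update_slice[0] = slice(5, 9, 2), i.e. elements 5 and 7 are read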
......@@ -1105,7 +1193,7 @@ class _fftw_distributor(object):
def inject(self, data, to_slices, data_update, from_slices, comm=None,
**kwargs):
## check if to_key and from_key is completely build of slices
## check if to_key and from_key are completely built of slices
if not np.all(
np.vectorize(lambda x: isinstance(x, slice))(to_slices)):
raise ValueError(about._errors.cstring(
......@@ -1126,7 +1214,103 @@ class _fftw_distributor(object):
from_slices = from_slices,
comm=comm,
**kwargs)
def extract_local_data(self, data_object):
## if data_object is not an ndarray or a d2o, cast it to an ndarray
if not (isinstance(data_object, np.ndarray) or
isinstance(data_object, distributed_data_object)):
data_object = np.array(data_object)
## check if the shapes are remotely compatible, reshape if possible
## and determine which dimensions match only via broadcasting
try:
(data_object, matching_dimensions) = \
self._reshape_foreign_data(data_object)
## if the shape-casting fails, try to fix things via local data
## matching
except(ValueError):
## Check if all the local shapes match the supplied data
local_matchQ = (self.local_shape == data_object.shape)
global_matchQ = self._allgather(local_matchQ)
## if the local shapes match, simply return the data_object
if np.all(global_matchQ):
extracted_data = data_object[:]
## if not, allgather the local data pieces and extract from this
else:
allgathered_data = self._allgather(data_object)
allgathered_data = np.concatenate(allgathered_data)
if allgathered_data.shape != self.global_shape:
raise ValueError(
about._errors.cstring(
"ERROR: supplied shapes do neither match globally nor locally"))
return self.extract_local_data(allgathered_data)
## if shape-casting was successful, extract the data
else:
## If the first dimension matches only via broadcasting...
## Case 1: ...do broadcasting. This procedure does not depend on the
## array type (ndarray or d2o)
if matching_dimensions[0] == False:
extracted_data = data_object[0:1]
## Case 2: First dimension fits directly and data_object is a d2o
elif isinstance(data_object, distributed_data_object):
## Check if the distribution_strategy and the comm match
## our own.
if type(self) == type(data_object.distributor) and\
self.comm == data_object.distributor.comm:
## Case 1: yes. Simply take the local data
extracted_data = data_object.data
else:
## Case 2: no. All nodes extract their local slice from the
## data_object
extracted_data =\
data_object[self.local_start:self.local_end]
## Case 3: First dimension fits directly and data_object is a generic
## array
else:
extracted_data =\
data_object[self.local_start:self.local_end]
return extracted_data
def _reshape_foreign_data(self, foreign):
## Case 1:
## check if the shapes match directly
if self.global_shape == foreign.shape:
matching_dimensions = [True,]*len(self.global_shape)
return (foreign, matching_dimensions)
## Case 2:
## if not, try to reshape the input data
## in particular, this will fail when foreign is a d2o as long as
## reshaping is not implemented
try:
output = foreign.reshape(self.global_shape)
matching_dimensions = [True,]*len(self.global_shape)
return (output, matching_dimensions)
except(ValueError, AttributeError):
pass
## Case 3:
## if this does not work, try to broadcast the shape
## check if the dimensions match
if len(self.global_shape) != len(foreign.shape):
raise ValueError(
about._errors.cstring("ERROR: unequal number of dimensions!"))
## check direct matches
direct_match = (np.array(self.global_shape) == np.array(foreign.shape))
## check broadcast compatibility
broadcast_match = (np.ones(len(self.global_shape), dtype=int) ==\
np.array(foreign.shape))
## combine the matches and assert that all are true
combined_match = (direct_match | broadcast_match)
if not np.all(combined_match):
raise ValueError(
about._errors.cstring("ERROR: incompatible shapes!"))
matching_dimensions = tuple(direct_match)
return (foreign, matching_dimensions)
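As a sanity check of the matching logic, a numpy-only sketch of the broadcast test (hypothetical shapes):

global_shape = (8, 4)
foreign_shape = (1, 4)
direct = np.array(global_shape) == np.array(foreign_shape)          # [False  True]
broadcast = np.ones(len(global_shape), dtype=int) == np.array(foreign_shape)  # [ True False]
assert np.all(direct | broadcast)  # compatible; axis 0 matches only via broadcasting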
def consolidate_data(self, data, target_rank='all', comm = None):
if comm == None:
comm = self.comm
......@@ -1243,6 +1427,9 @@ class _not_distributor(object):
def inject(self, data, to_slices = (slice(None),), data_update = None,
from_slices = (slice(None),)):
data[to_slices] = data_update[from_slices]
def extract_local_data(self, data_object):
return data_object.get_full_data()
def save_data(self, *args, **kwargs):
raise AttributeError(about._errors.cstring(
......
......@@ -8,7 +8,7 @@ from nifty.nifty_about import about
# If this fails fall back to local gfft_rg
try:
import pyfftw
import pyfftw_BAD
fft_machine='pyfftw'
except(ImportError):
try:
......@@ -315,12 +315,14 @@ if fft_machine == 'pyfftw':
result = p.output_array * current_plan_and_info.\
get_domain_centering_mask()
"""
## renorm the result according to the convention of gfft
if current_plan_and_info.direction == 'FFTW_FORWARD':
result = result/float(result.size)
else:
result *= float(result.size)
"""
## build the return object according to the input val
try:
if return_val.distribution_strategy == 'fftw':
......@@ -391,10 +393,12 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
## if the input is a distributed_data_object, extract the data
if isinstance(val, distributed_data_object):
d2oQ = True
val = val.get_full_data()
temp = val.get_full_data()
else:
temp = val
## transform and return
if(domain.datatype==np.float64):
temp = gfft.gfft(val.astype(np.complex128),
temp = gfft.gfft(temp.astype(np.complex128),
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
......@@ -408,7 +412,7 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
alpha=-1,
verbose=False)
else:
temp = gfft.gfft(val,
temp = gfft.gfft(temp,
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
......@@ -422,7 +426,13 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
alpha=-1,
verbose=False)
if d2oQ == True:
val.set_full_data(temp)
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(temp)
## If the values living in domain are purely real, the
## result of the fft is hermitian
if domain.paradict['complexity'] == 0:
new_val.hermitian = True
val = new_val
else:
val = temp
......
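At the user level, the new gfft branch means that transforming a field on a purely real domain returns a fresh complex d2o flagged as hermitian instead of mutating the input in place; a hedged sketch:

x_space = rg_space([128,])          # complexity 0: purely real domain
f = field(x_space, random='gau')
F = f.transform()                   # the resulting d2o should carry hermitian == True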
......@@ -22,473 +22,6 @@
from __future__ import division
import numpy as np
from mpi4py import MPI
from nifty.nifty_about import about
from nifty.nifty_mpi_data import distributed_data_object
class power_indices(object):
def __init__(self, shape, dgrid, zerocentered=False, log=False, nbin=None,
binbounds=None, comm=MPI.COMM_WORLD):
"""
Returns an instance of the power_indices class. Given the shape and
the density of an underlying rectangular grid, it provides the user
with the pindex, kindex, rho and pundex. The indices are binned
according to the supplied parameter scheme. If desired, computed
results are stored for future reuse.
Parameters
----------
shape : tuple, list, ndarray
Array-like object which specifies the shape of the underlying
rectangular grid
dgrid : tuple, list, ndarray
Array-like object which specifies the step-width of the
underlying grid
zerocentered : boolean, tuple/list/ndarray of boolean *optional*
Specifies which dimensions are zerocentered. (default:False)
log : bool *optional*
Flag specifying if the binning of the default indices is
performed on a logarithmic scale.
nbin : integer *optional*
Number of used bins for the binning of the default indices.
binbounds : {list, array}
Array-like inner boundaries of the used bins of the default
indices.
"""
## Basic inits and consistency checks
self.comm = comm
self.shape = np.array(shape, dtype = int)
self.dgrid = np.abs(np.array(dgrid))
if self.shape.shape != self.dgrid.shape:
raise ValueError(about._errors.cstring("ERROR: The supplied shape\
and dgrid do not have the same dimensionality"))
self.zerocentered = self.__cast_zerocentered__(zerocentered)
## Compute the global kdict
self.kdict = self.compute_kdict()
## Initialize the dictionary which stores all individual index-dicts
self.global_dict={}
## Calculate the default dictionary according to the kwargs and set it
## as default
self.get_index_dict(log=log, nbin=nbin, binbounds=binbounds,
store=True)
self.set_default(log=log, nbin=nbin, binbounds=binbounds)
## Redirect direct calls on a power_indices instance to the
## default_indices dict
def __getitem__(self, x):
return self.default_indices.get(x)
def __getattr__(self, x):
return self.default_indices.__getattribute__(x)
def __cast_zerocentered__(self, zerocentered=False):
"""
internal helper function which brings the zerocentered input into
the form of a boolean tuple
"""
zc = np.array(zerocentered).astype(bool)
if zc.shape == self.shape.shape:
return tuple(zc)
else:
temp = np.empty(shape=self.shape.shape, dtype=bool)
temp[:] = zc
return tuple(temp)
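The cast broadcasts a scalar flag to every axis and passes per-axis sequences through unchanged; illustrative values for a 2-d grid:

## __cast_zerocentered__(False)          -> (False, False)
## __cast_zerocentered__([True, False])  -> (True, False)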
def __cast_config__(self, *args, **kwargs):
"""
internal helper function which casts the various combinations of
possible parameters into a properly defaulted dictionary
"""
temp_config_dict = kwargs.get('config_dict', None)
if temp_config_dict != None:
return self.__cast_config_helper__(**temp_config_dict)
else:
temp_log = kwargs.get("log", None)
temp_nbin = kwargs.get("nbin", None)
temp_binbounds = kwargs.get("binbounds", None)
return self.__cast_config_helper__(log=temp_log,
nbin=temp_nbin,
binbounds=temp_binbounds)
def __cast_config_helper__(self, log, nbin, binbounds):
"""
internal helper function which sets the defaults for the
__cast_config__ function
"""