Commit d8f0ba70 authored by ultimanet's avatar ultimanet
Browse files

Misc changes. Separation of field and space functionalities further improved....

Misc changes. Separation of field and space functionalities further improved. The fft of a field now runs purely with a distributed data object as it is now the natural data object of rg_space.
parent dc144ef8
......@@ -213,6 +213,11 @@ class lm_space(point_space):
"""
return self.paradict['mmax']
def shape(self):
mmax = self.paradict('mmax')
lmax = self.paradict('lmax')
return np.array([(mmax+1)*(lmax+1)-(lmax+1)*(mmax//2)], dtype=int)
def dim(self,split=False):
"""
Computes the dimension of the space, i.e.\ the number of spherical
......@@ -237,9 +242,11 @@ class lm_space(point_space):
"""
## dim = (mmax+1)*(lmax-mmax/2+1)
if(split):
return np.array([(self.para[0]+1)*(self.para[1]+1)-(self.para[1]+1)*self.para[1]//2],dtype=np.int)
return self.shape()
#return np.array([(self.para[0]+1)*(self.para[1]+1)-(self.para[1]+1)*self.para[1]//2],dtype=np.int)
else:
return (self.para[0]+1)*(self.para[1]+1)-(self.para[1]+1)*self.para[1]//2
return np.prod(self.shape())
#return (self.para[0]+1)*(self.para[1]+1)-(self.para[1]+1)*self.para[1]//2
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -1035,7 +1042,7 @@ class gl_space(point_space):
self.datatype = datatype
self.discrete = False
self.vol = gl.vol(self.para[0],nlon=self.para[1]).astype(self.datatype)
self.vol = gl.vol(self.paradict['nlat'],nlon=self.paradict['nlon']).astype(self.datatype)
@property
......@@ -1074,6 +1081,9 @@ class gl_space(point_space):
"""
return self.paradict['nlon']
def shape(self):
return np.array([(self.paradict['nlat']*self.paradict['nlon'])], dtype=np.int)
def dim(self,split=False):
"""
Computes the dimension of the space, i.e.\ the number of pixels.
......@@ -1091,9 +1101,11 @@ class gl_space(point_space):
"""
## dim = nlat*nlon
if(split):
return np.array([self.para[0]*self.para[1]],dtype=np.int)
return self.shape()
#return np.array([self.para[0]*self.para[1]],dtype=np.int)
else:
return self.para[0]*self.para[1]
return np.prod(self.shape())
#return self.para[0]*self.para[1]
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -1699,7 +1711,7 @@ class hp_space(point_space):
self.datatype = np.float64
self.discrete = False
self.vol = np.array([4*pi/(12*self.para[0]**2)],dtype=self.datatype)
self.vol = np.array([4*pi/(12*self.paradict['nside']**2)],dtype=self.datatype)
@property
def para(self):
......@@ -1725,6 +1737,8 @@ class hp_space(point_space):
"""
return self.paradict['nside']
def shape(self):
return np.array([12*self.paradict['nside']**2], dtype=np.int)
def dim(self,split=False):
"""
......@@ -1743,9 +1757,11 @@ class hp_space(point_space):
"""
## dim = 12*nside**2
if(split):
return np.array([12*self.para[0]**2],dtype=np.int)
return self.shape()
#return np.array([12*self.para[0]**2],dtype=np.int)
else:
return 12*self.para[0]**2
return np.prod(self.shape())
#return 12*self.para[0]**2
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
......@@ -1011,12 +1011,9 @@ class space(object):
self.discrete = True
self.vol = np.real(np.array([1],dtype=self.datatype))
self.shape = None
@property
def para(self):
return self.paradict['default']
#return self.distributed_val
@para.setter
def para(self, x):
......@@ -1033,12 +1030,12 @@ class space(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def getitem(self, key):
def getitem(self, data, key):
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'getitem'."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def setitem(self, key):
def setitem(self, data, key):
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'getitem'."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -1053,6 +1050,8 @@ class space(object):
def norm(self, x, q):
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'norm'."))
def shape(self):
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'shape'."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def dim(self,split=False):
......@@ -1217,8 +1216,30 @@ class space(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast(self, x):
return self.enforce_values(x)
def cast(self, x, verbose=False):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
return self.enforce_values(x, extend=True)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -1827,13 +1848,24 @@ class point_space(space):
def para(self):
temp = np.array([self.paradict['num']], dtype=int)
return temp
#return self.distributed_val
@para.setter
def para(self, x):
self.paradict['num'] = x
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def getitem(self, data, key):
return data[key]
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def setitem(self, data, update, key):
data[key]=update
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def unary_operation(self, x, op='None', **kwargs):
"""
x must be a numpy array which is compatible with the space!
......@@ -1947,6 +1979,10 @@ class point_space(space):
"""
return self.para[0]
def shape(self):
return np.array([self.paradict['num']])
def dim(self,split=False):
"""
Computes the dimension of the space, i.e.\ the number of points.
......@@ -1964,9 +2000,11 @@ class point_space(space):
"""
## dim = num
if(split):
return np.array([self.para[0]],dtype=np.int)
return self.shape()
#return np.array([self.para[0]],dtype=np.int)
else:
return self.para[0]
return np.prod(self.shape())
#return self.para[0]
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -4711,6 +4749,12 @@ class nested_space(space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def shape(self):
temp = []
for i in range(self.paradict.ndim):
temp = np.append(temp, self.paradict[i])
return temp
def dim(self,split=False):
"""
Computes the dimension of the product space.
......@@ -4729,9 +4773,11 @@ class nested_space(space):
Dimension(s) of the space.
"""
if(split):
return self.para
return self.shape()
#return self.para
else:
return np.prod(self.para,axis=0,dtype=None,out=None)
return np.prod(self.shape())
#return np.prod(self.para,axis=0,dtype=None,out=None)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -5436,7 +5482,10 @@ class field(object):
else:
self.domain.check_codomain(target)
self.target = target
self.val = self.domain.cast(val)
"""
self.distributed_val = distributed_data_object(global_shape=domain.dim(split=True), dtype=domain.datatype)
## check values
......@@ -5454,6 +5503,8 @@ class field(object):
@val.setter
def val(self, x):
return self.distributed_val.set_full_data(x)
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def dim(self,split=False):
......@@ -5792,7 +5843,7 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def transform(self,target=None,overwrite=False,**kwargs):
def transform(self, target=None, overwrite=False, **kwargs):
"""
Computes the transform of the field using the appropriate conjugate
transformation.
......@@ -5821,12 +5872,16 @@ class field(object):
target = self.target
else:
self.domain.check_codomain(target) ## a bit pointless
new_val = self.domain.calc_transform(self.val,
codomain=target,
**kwargs)
if(overwrite):
self.val = self.domain.calc_transform(self.val,codomain=target,field_val=self.distributed_val, **kwargs)
self.val = new_val
self.target = self.domain
self.domain = target
else:
return field(target,val=self.domain.calc_transform(self.val,codomain=target, field_val=self.distributed_val, **kwargs),target=self.domain)
return field(target, val=new_val, target=self.domain)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
This diff is collapsed.
......@@ -29,7 +29,14 @@ class space_paradict(object):
self.parameters = {}
for key in kwargs:
self[key] = kwargs[key]
def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return self.parameters.__repr__()
......
# -*- coding: utf-8 -*-
import numpy as np
from nifty import nifty_mpi_data
from nifty.nifty_mpi_data import distributed_data_object
# Try to import pyfftw. If this fails fall back to gfft. If this fails fall back to local gfft_rg
......@@ -50,7 +50,7 @@ class fft(object):
----------
None
"""
def transform(self,field_val,domain,codomain,**kwargs):
def transform(self, val, domain, codomain, **kwargs):
"""
A generic ff-transform function.
......@@ -71,9 +71,10 @@ class fft(object):
if fft_machine == 'pyfftw':
## The instances of plan_and_info store the fftw plan and all
## other information needed in order to perform a mpi-fftw transformation
class _fftw_plan_and_info(object):
class _fftw_plan_and_info(fft):
def __init__(self,domain,codomain,fft_fftw_context,**kwargs):
self.compute_plan_and_info(domain,codomain,fft_fftw_context,**kwargs)
self.compute_plan_and_info(domain, codomain, fft_fftw_context,
**kwargs)
def set_plan(self, x):
self.plan=x
......@@ -90,19 +91,21 @@ if fft_machine == 'pyfftw':
def get_codomain_centering_mask(self):
return self.codomain_centering_mask
def compute_plan_and_info(self, domain, codomain,fft_fftw_context,**kwargs):
def compute_plan_and_info(self, domain, codomain, fft_fftw_context,
**kwargs):
self.input_dtype = 'complex128'
self.output_dtype = 'complex128'
self.global_input_shape = domain.dim(split=True)
self.global_output_shape = codomain.dim(split=True)
self.global_input_shape = domain.shape()
self.global_output_shape = codomain.shape()
self.fftw_local_size = pyfftw.local_size(self.global_input_shape)
self.in_zero_centered_dimensions = domain.zerocenter()[::-1]
self.out_zero_centered_dimensions = codomain.zerocenter()[::-1]
self.in_zero_centered_dimensions = domain.paradict['zerocenter']
self.out_zero_centered_dimensions = codomain.paradict['zerocenter']
self.local_node_dimensions = np.append((self.fftw_local_size[1],),self.global_input_shape[1:])
self.local_node_dimensions = np.append((self.fftw_local_size[1],),
self.global_input_shape[1:])
self.offsetQ = self.fftw_local_size[2]%2
if codomain.fourier == True:
......@@ -147,10 +150,12 @@ if fft_machine == 'pyfftw':
## to a certain set of (field_val, domain, codomain) sets.
self.plan_dict = {}
## initialize the dictionary which stores the values from get_centering_mask
## initialize the dictionary which stores the values from
## get_centering_mask
self.centering_mask_dict = {}
def get_centering_mask(self, to_center_input, dimensions_input, offset_input=0):
def get_centering_mask(self, to_center_input, dimensions_input,
offset_input=0):
"""
Computes the mask, used to (de-)zerocenter domain and target
fields.
......@@ -177,7 +182,8 @@ if fft_machine == 'pyfftw':
to_center = np.array(to_center_input)
dimensions = np.array(dimensions_input)
if np.all(dimensions == np.array(1)) or np.all(dimensions == np.array([1])):
if np.all(dimensions == np.array(1)) or \
np.all(dimensions == np.array([1])):
return dimensions
## The dimensions of size 1 must be sorted out for computing the
## centering_mask. The depth of the array will be restored in the
......@@ -199,7 +205,8 @@ if fft_machine == 'pyfftw':
offset[0] = int(offset_input)
## check for dimension match
if to_center.size != dimensions.size:
raise TypeError('The length of the supplied lists does not match.')
raise TypeError(\
'The length of the supplied lists does not match.')
## build up the value memory
## compute an identifier for the parameter set
......@@ -207,7 +214,13 @@ if fft_machine == 'pyfftw':
if not temp_id in self.centering_mask_dict:
## use np.tile in order to stack the core alternation scheme
## until the desired format is constructed.
core = np.fromfunction(lambda *args : (-1)**(np.tensordot(to_center,args+offset.reshape(offset.shape+(1,)*(np.array(args).ndim-1)),1)), (2,)*to_center.size)
core = np.fromfunction(
lambda *args : (-1)**\
(np.tensordot(to_center,args + \
offset.reshape(offset.shape + \
(1,)*(np.array(args).ndim - 1)),1)),\
(2,)*to_center.size)
centering_mask = np.tile(core,dimensions//2)
## for the dimensions of odd size corresponding slices must be added
for i in range(centering_mask.ndim):
......@@ -215,9 +228,11 @@ if fft_machine == 'pyfftw':
if (dimensions%2)[i]==0:
continue
## prepare the slice object
temp_slice=(slice(None),)*i + (slice(-2,-1,1),) + (slice(None),)*(centering_mask.ndim -1 - i)
temp_slice = (slice(None),)*i + (slice(-2,-1,1),) +\
(slice(None),)*(centering_mask.ndim -1 - i)
## append the slice to the centering_mask
centering_mask = np.append(centering_mask,centering_mask[temp_slice],axis=i)
centering_mask = np.append(centering_mask,
centering_mask[temp_slice],axis=i)
## Add depth to the centering_mask where the length of a
## dimension was one
temp_slice = ()
......@@ -236,16 +251,17 @@ if fft_machine == 'pyfftw':
temp_id = (domain.__identifier__(), codomain.__identifier__())
## generate the plan_and_info object if not already there
if not temp_id in self.plan_dict:
self.plan_dict[temp_id]=_fftw_plan_and_info(domain,codomain,self,**kwargs)
self.plan_dict[temp_id]=_fftw_plan_and_info(domain, codomain,
self, **kwargs)
return self.plan_dict[temp_id]
def transform(self,val,domain,codomain, field_val, **kwargs):
def transform(self, val, domain, codomain, **kwargs):
"""
The pyfftw transform function.
Parameters
----------
field_val : distributed_data_object
val : distributed_data_object or numpy.ndarray
The value-array of the field which is supposed to
be transformed.
......@@ -263,32 +279,72 @@ if fft_machine == 'pyfftw':
result : np.ndarray
Fourier-transformed pendant of the input field.
"""
current_plan_and_info=self._get_plan_and_info(domain,codomain,**kwargs)
## Prepare the input data
current_plan_and_info=self._get_plan_and_info(domain, codomain,
**kwargs)
## Prepare the environment variables
local_size = current_plan_and_info.fftw_local_size
local_start = local_size[2]
local_end = local_start + local_size[1]
val = field_val.get_data(slice(local_start,local_end))
val *= current_plan_and_info.get_codomain_centering_mask()
## Prepare the input data
## Case 1: val is a distributed_data_object
if isinstance(val, distributed_data_object):
return_val = val.copy_empty(global_shape =\
tuple(current_plan_and_info.global_output_shape),
dtype = np.complex128)
## If the distribution strategy of the d2o is fftw, extract
## the data directly
if val.distribution_strategy == 'fftw':
local_val = val.get_local_data()
else:
local_val = val.get_data(slice(local_start, local_end))
## Case 2: val is a numpy array carrying the full data
else:
local_val = val[slice(local_start, local_end)]
local_val *= current_plan_and_info.get_codomain_centering_mask()
## Define a abbreviation for the fftw plan
p = current_plan_and_info.get_plan()
## load the field into the plan
if p.has_input:
p.input_array[:] = val
p.input_array[:] = local_val
## execute the plan
p()
result = p.output_array*current_plan_and_info.get_domain_centering_mask()
result = p.output_array * current_plan_and_info.\
get_domain_centering_mask()
## renorm the result according to the convention of gfft
if current_plan_and_info.direction == 'FFTW_FORWARD':
result = result/float(result.size)
else:
result *= float(result.size)
## build a distributed_data_object
data_object = nifty_mpi_data.distributed_data_object(global_shape = tuple(current_plan_and_info.global_output_shape), dtype = np.complex128, distribution_strategy='fftw')
data_object.set_local_data(data=result)
return data_object.get_full_data()
## build the return object according to the input val
try:
if return_val.distribution_strategy == 'fftw':
return_val.set_local_data(data = result)
else:
return_val.set_data(data = result,
key = slice(local_start, local_end))
## If the values living in domain are purely real, the
## result of the fft is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
## In case the input val was not a distributed data obect, the try
## will produce a NameError
except(NameError):
return_val = distributed_data_object(
global_shape =\
tuple(current_plan_and_info.global_output_shape),
dtype = np.complex128,
distribution_strategy='fftw')
return_val.set_local_data(data = result)
return_val = return_val.get_full_data()
return return_val
......@@ -309,7 +365,7 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
Parameters
----------
val : numpy.ndarray
val : numpy.ndarray or distributed_data_object
The value-array of the field which is supposed to
be transformed.
......@@ -332,9 +388,43 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
ftmachine = "fft"
else:
ftmachine = "ifft"
## if the input is a distributed_data_object, extract the data
if isinstance(val, distributed_data_object):
d2oQ = True
val = val.get_full_data()
## transform and return
if(domain.datatype==np.float64):
return gfft.gfft(val.astype(np.complex128),in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
temp = gfft.gfft(val.astype(np.complex128),
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
in_zero_center=domain.para[-naxes:].\
astype(np.bool).tolist(),
out_zero_center=codomain.para[-naxes:].\
astype(np.bool).tolist(),
enforce_hermitian_symmetry = \
bool(codomain.para[naxes]==1),
W=-1,
alpha=-1,
verbose=False)
else:
return gfft.gfft(val,in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
\ No newline at end of file
temp = gfft.gfft(val,
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
in_zero_center=domain.para[-naxes:].\
astype(np.bool).tolist(),
out_zero_center=codomain.para[-naxes:].\
astype(np.bool).tolist(),
enforce_hermitian_symmetry = \
bool(codomain.para[naxes]==1),
W=-1,
alpha=-1,
verbose=False)
if d2oQ == True:
val.set_full_data(temp)
else:
val = temp
return val
\ No newline at end of file
This diff is collapsed.
......@@ -438,6 +438,8 @@ class power_indices(object):
else:
k = kindex
dk = np.max(k[2:]-k[1:-1]) ## minimal dk
print ('k', k)
print ('dk', dk)
if(nbin is None):
nbin = int((k[-1]-0.5*(k[2]+k[1]))/dk-0.5) ## maximal nbin
else:
......@@ -446,6 +448,7 @@ class power_indices(object):
binbounds = np.r_[0.5*(3*k[1]-k[2]),0.5*(k[1]+k[2])+dk*np.arange(nbin-2)]
if(log):
binbounds = np.exp(binbounds)
print nbin
## reordering
reorder = np.searchsorted(binbounds,kindex)
rho_ = np.zeros(len(binbounds)+1,dtype=rho.dtype)
......@@ -475,7 +478,7 @@ if __name__ == '__main__':
comm = MPI.COMM_WORLD
rank = comm.rank
size = comm.size
p = power_indices((4,4),(1,1), zerocentered=(True,True), nbin = 4)
p = power_indices((4,4),(1,1), zerocentered=(True,True), nbin = 5)
"""
obj = p.default_indices['nkdict']
for i in np.arange(size):
......@@ -1062,13 +1065,14 @@ def nhermitianize_fast(field,zerocentered,special=False):
index = tuple(ii*maxindex)
field[index] *= np.sqrt(0.5)
else: ## regular case
field = 0.5*(field+dummy)
#field = 0.5*(field+dummy)
field = dummy
## reshift zerocentered axes
if(np.any(zerocentered==True)):
field = np.fft.fftshift(field,axes=shiftaxes(zerocentered))
return field
def random_hermitian_pm1(datatype,zerocentered,shape):