Commit 0f24d97e authored by ultimanet's avatar ultimanet
Browse files

added distributed_data_object

parent 47ecedc4
...@@ -25,6 +25,7 @@ from nifty_cmaps import * ...@@ -25,6 +25,7 @@ from nifty_cmaps import *
from nifty_power import * from nifty_power import *
from nifty_tools import * from nifty_tools import *
from nifty_explicit import * from nifty_explicit import *
from nifty_mpi_data import distributed_data_object
## optional submodule `rg` ## optional submodule `rg`
try: try:
...@@ -42,3 +43,4 @@ except(ImportError): ...@@ -42,3 +43,4 @@ except(ImportError):
from demos import * from demos import *
from pickling import * from pickling import *
#import pyximport; pyximport.install(pyimport = True)
\ No newline at end of file
...@@ -148,7 +148,7 @@ import pylab as pl ...@@ -148,7 +148,7 @@ import pylab as pl
from multiprocessing import Pool as mp from multiprocessing import Pool as mp
from multiprocessing import Value as mv from multiprocessing import Value as mv
from multiprocessing import Array as ma from multiprocessing import Array as ma
from nifty_mpi_data import distributed_data_object
__version__ = "1.0.6" __version__ = "1.0.6"
...@@ -4983,12 +4983,24 @@ class field(object): ...@@ -4983,12 +4983,24 @@ class field(object):
else: else:
self.domain.check_codomain(target) self.domain.check_codomain(target)
self.target = target self.target = target
self.distributed_val = distributed_data_object(global_shape=domain.dim(split=True), dtype=domain.datatype)
## check values ## check values
if(val is None): if(val is None):
self.val = self.domain.get_random_values(codomain=self.target,**kwargs) self.val = self.domain.get_random_values(codomain=self.target,**kwargs)
else: else:
self.val = self.domain.enforce_values(val,extend=True) self.val = self.domain.enforce_values(val,extend=True)
@property
def val(self):
return self.distributed_val.get_full_data()
#return self.distributed_val
@val.setter
def val(self, x):
return self.distributed_val.set_full_data(x)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def dim(self,split=False): def dim(self,split=False):
...@@ -5357,11 +5369,11 @@ class field(object): ...@@ -5357,11 +5369,11 @@ class field(object):
else: else:
self.domain.check_codomain(target) ## a bit pointless self.domain.check_codomain(target) ## a bit pointless
if(overwrite): if(overwrite):
self.val = self.domain.calc_transform(self.val,codomain=target,**kwargs) self.val = self.domain.calc_transform(self.val,codomain=target,field_val=self.distributed_val, **kwargs)
self.target = self.domain self.target = self.domain
self.domain = target self.domain = target
else: else:
return field(target,val=self.domain.calc_transform(self.val,codomain=target,**kwargs),target=self.domain) return field(target,val=self.domain.calc_transform(self.val,codomain=target, field_val=self.distributed_val, **kwargs),target=self.domain)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np import numpy as np
from nifty import nifty_mpi_data
# Try to import pyfftw. If this fails fall back to gfft. If this fails fall back to local gfft_rg # Try to import pyfftw. If this fails fall back to gfft. If this fails fall back to local gfft_rg
...@@ -54,7 +55,7 @@ class fft(object): ...@@ -54,7 +55,7 @@ class fft(object):
Parameters Parameters
---------- ----------
field_val : numpy.ndarray field_val : distributed_data_object
The value-array of the field which is supposed to The value-array of the field which is supposed to
be transformed. be transformed.
...@@ -140,8 +141,14 @@ if fft_machine == 'pyfftw': ...@@ -140,8 +141,14 @@ if fft_machine == 'pyfftw':
None None
""" """
## initialize the dictionary which stores the values from get_centering_mask def __init__(self):
centering_mask_dict = {} ## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
self.plan_dict = {}
## initialize the dictionary which stores the values from get_centering_mask
self.centering_mask_dict = {}
def get_centering_mask(self, to_center_input, dimensions_input, offset_input=0): def get_centering_mask(self, to_center_input, dimensions_input, offset_input=0):
""" """
Computes the mask, used to (de-)zerocenter domain and target Computes the mask, used to (de-)zerocenter domain and target
...@@ -197,9 +204,7 @@ if fft_machine == 'pyfftw': ...@@ -197,9 +204,7 @@ if fft_machine == 'pyfftw':
self.centering_mask_dict[temp_id] = centering_mask self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id] return self.centering_mask_dict[temp_id]
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
plan_dict = {}
def _get_plan_and_info(self,domain,codomain,**kwargs): def _get_plan_and_info(self,domain,codomain,**kwargs):
## generate a id-tuple which identifies the domain-codomain setting ## generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__identifier__(), codomain.__identifier__()) temp_id = (domain.__identifier__(), codomain.__identifier__())
...@@ -208,13 +213,13 @@ if fft_machine == 'pyfftw': ...@@ -208,13 +213,13 @@ if fft_machine == 'pyfftw':
self.plan_dict[temp_id]=_fftw_plan_and_info(domain,codomain,self,**kwargs) self.plan_dict[temp_id]=_fftw_plan_and_info(domain,codomain,self,**kwargs)
return self.plan_dict[temp_id] return self.plan_dict[temp_id]
def transform(self,field_val,domain,codomain,**kwargs): def transform(self,val,domain,codomain, field_val, **kwargs):
""" """
The pyfftw transform function. The pyfftw transform function.
Parameters Parameters
---------- ----------
field_val : numpy.ndarray field_val : distributed_data_object
The value-array of the field which is supposed to The value-array of the field which is supposed to
be transformed. be transformed.
...@@ -234,15 +239,34 @@ if fft_machine == 'pyfftw': ...@@ -234,15 +239,34 @@ if fft_machine == 'pyfftw':
""" """
current_plan_and_info=self._get_plan_and_info(domain,codomain,**kwargs) current_plan_and_info=self._get_plan_and_info(domain,codomain,**kwargs)
## Prepare the input data ## Prepare the input data
field_val*=current_plan_and_info.get_codomain_centering_mask()
local_size = current_plan_and_info.fftw_local_size
local_start = local_size[2]
local_end = local_start + local_size[1]
val = field_val.get_data(slice(local_start,local_end))
val *= current_plan_and_info.get_codomain_centering_mask()
## Define a abbreviation for the fftw plan ## Define a abbreviation for the fftw plan
p = current_plan_and_info.get_plan() p = current_plan_and_info.get_plan()
## load the field into the plan ## load the field into the plan
if p.has_input: if p.has_input:
p.input_array[:] = field_val p.input_array[:] = val
## execute the plan ## execute the plan
p() p()
return p.output_array*current_plan_and_info.get_domain_centering_mask() result = p.output_array*current_plan_and_info.get_domain_centering_mask()
## renorm the result according to the convention of gfft
if current_plan_and_info.direction == 'FFTW_FORWARD':
result = result/float(result.size)
else:
result *= float(result.size)
## build a distributed_data_object
data_object = nifty_mpi_data.distributed_data_object(global_shape = current_plan_and_info.global_output_shape, dtype = np.complex128, distribution_strategy='fftw')
data_object.set_local_data(data=result)
return data_object.get_full_data()
elif fft_machine == 'gfft' or 'gfft_fallback': elif fft_machine == 'gfft' or 'gfft_fallback':
...@@ -255,13 +279,13 @@ elif fft_machine == 'gfft' or 'gfft_fallback': ...@@ -255,13 +279,13 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
None None
""" """
def transform(self,field_val,domain,codomain,**kwargs): def transform(self, val, domain, codomain, **kwargs):
""" """
The gfft transform function. The gfft transform function.
Parameters Parameters
---------- ----------
field_val : numpy.ndarray val : numpy.ndarray
The value-array of the field which is supposed to The value-array of the field which is supposed to
be transformed. be transformed.
...@@ -286,7 +310,7 @@ elif fft_machine == 'gfft' or 'gfft_fallback': ...@@ -286,7 +310,7 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
ftmachine = "ifft" ftmachine = "ifft"
## transform and return ## transform and return
if(domain.datatype==np.float64): if(domain.datatype==np.float64):
return gfft.gfft(field_val.astype(np.complex128),in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False) return gfft.gfft(val.astype(np.complex128),in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
else: else:
return gfft.gfft(field_val,in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False) return gfft.gfft(val,in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
\ No newline at end of file
...@@ -42,6 +42,7 @@ from nifty.nifty_core import about, \ ...@@ -42,6 +42,7 @@ from nifty.nifty_core import about, \
random, \ random, \
space, \ space, \
field field
import nifty.nifty_mpi_data
import nifty.smoothing as gs import nifty.smoothing as gs
import powerspectrum as gp import powerspectrum as gp
''' '''
...@@ -204,7 +205,7 @@ class rg_space(space): ...@@ -204,7 +205,7 @@ class rg_space(space):
self.fourier = bool(fourier) self.fourier = bool(fourier)
## Initializes the fast-fourier-transform machine, which will be used ## Initializes the fast-fourier-transform machine, which will be used
## to transform the spaace ## to transform the space
self.fft_machine = fft_rg.fft_factory() self.fft_machine = fft_rg.fft_factory()
...@@ -823,11 +824,9 @@ class rg_space(space): ...@@ -823,11 +824,9 @@ class rg_space(space):
## of transformation is infered from the fourier attribute of the ## of transformation is infered from the fourier attribute of the
## supplied space ## supplied space
if(codomain.fourier): if(codomain.fourier):
#ftmachine = "fft"
## correct for 'fft' ## correct for 'fft'
x = self.calc_weight(x,power=1) x = self.calc_weight(x,power=1)
else: else:
#ftmachine = "ifft"
## correct for 'ifft' ## correct for 'ifft'
x = self.calc_weight(x,power=1) x = self.calc_weight(x,power=1)
x *= self.dim(split=False) x *= self.dim(split=False)
...@@ -837,7 +836,7 @@ class rg_space(space): ...@@ -837,7 +836,7 @@ class rg_space(space):
#ftmachine = "none" #ftmachine = "none"
## transform ## transform
Tx = self.fft_machine.transform(x,self,codomain) Tx = self.fft_machine.transform(x,self,codomain,**kwargs)
## check complexity ## check complexity
if(not codomain.para[naxes]): ## purely real if(not codomain.para[naxes]): ## purely real
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment