Commit 0f24d97e authored by ultimanet's avatar ultimanet
Browse files

added distributed_data_object

parent 47ecedc4
......@@ -25,6 +25,7 @@ from nifty_cmaps import *
from nifty_power import *
from nifty_tools import *
from nifty_explicit import *
from nifty_mpi_data import distributed_data_object
## optional submodule `rg`
try:
......@@ -42,3 +43,4 @@ except(ImportError):
from demos import *
from pickling import *
#import pyximport; pyximport.install(pyimport = True)
\ No newline at end of file
......@@ -148,7 +148,7 @@ import pylab as pl
from multiprocessing import Pool as mp
from multiprocessing import Value as mv
from multiprocessing import Array as ma
from nifty_mpi_data import distributed_data_object
__version__ = "1.0.6"
......@@ -4983,12 +4983,24 @@ class field(object):
else:
self.domain.check_codomain(target)
self.target = target
self.distributed_val = distributed_data_object(global_shape=domain.dim(split=True), dtype=domain.datatype)
## check values
if(val is None):
self.val = self.domain.get_random_values(codomain=self.target,**kwargs)
else:
self.val = self.domain.enforce_values(val,extend=True)
@property
def val(self):
return self.distributed_val.get_full_data()
#return self.distributed_val
@val.setter
def val(self, x):
return self.distributed_val.set_full_data(x)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def dim(self,split=False):
......@@ -5357,11 +5369,11 @@ class field(object):
else:
self.domain.check_codomain(target) ## a bit pointless
if(overwrite):
self.val = self.domain.calc_transform(self.val,codomain=target,**kwargs)
self.val = self.domain.calc_transform(self.val,codomain=target,field_val=self.distributed_val, **kwargs)
self.target = self.domain
self.domain = target
else:
return field(target,val=self.domain.calc_transform(self.val,codomain=target,**kwargs),target=self.domain)
return field(target,val=self.domain.calc_transform(self.val,codomain=target, field_val=self.distributed_val, **kwargs),target=self.domain)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
# -*- coding: utf-8 -*-
import numpy as np
from nifty import nifty_mpi_data
# Try to import pyfftw. If this fails fall back to gfft. If this fails fall back to local gfft_rg
......@@ -54,7 +55,7 @@ class fft(object):
Parameters
----------
field_val : numpy.ndarray
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
......@@ -140,8 +141,14 @@ if fft_machine == 'pyfftw':
None
"""
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict = {}
def __init__(self):
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
self.plan_dict = {}
## initialize the dictionary which stores the values from get_centering_mask
self.centering_mask_dict = {}
def get_centering_mask(self, to_center_input, dimensions_input, offset_input=0):
"""
Computes the mask, used to (de-)zerocenter domain and target
......@@ -197,9 +204,7 @@ if fft_machine == 'pyfftw':
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
plan_dict = {}
def _get_plan_and_info(self,domain,codomain,**kwargs):
## generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__identifier__(), codomain.__identifier__())
......@@ -208,13 +213,13 @@ if fft_machine == 'pyfftw':
self.plan_dict[temp_id]=_fftw_plan_and_info(domain,codomain,self,**kwargs)
return self.plan_dict[temp_id]
def transform(self,field_val,domain,codomain,**kwargs):
def transform(self,val,domain,codomain, field_val, **kwargs):
"""
The pyfftw transform function.
Parameters
----------
field_val : numpy.ndarray
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
......@@ -234,15 +239,34 @@ if fft_machine == 'pyfftw':
"""
current_plan_and_info=self._get_plan_and_info(domain,codomain,**kwargs)
## Prepare the input data
field_val*=current_plan_and_info.get_codomain_centering_mask()
local_size = current_plan_and_info.fftw_local_size
local_start = local_size[2]
local_end = local_start + local_size[1]
val = field_val.get_data(slice(local_start,local_end))
val *= current_plan_and_info.get_codomain_centering_mask()
## Define a abbreviation for the fftw plan
p = current_plan_and_info.get_plan()
## load the field into the plan
if p.has_input:
p.input_array[:] = field_val
p.input_array[:] = val
## execute the plan
p()
return p.output_array*current_plan_and_info.get_domain_centering_mask()
result = p.output_array*current_plan_and_info.get_domain_centering_mask()
## renorm the result according to the convention of gfft
if current_plan_and_info.direction == 'FFTW_FORWARD':
result = result/float(result.size)
else:
result *= float(result.size)
## build a distributed_data_object
data_object = nifty_mpi_data.distributed_data_object(global_shape = current_plan_and_info.global_output_shape, dtype = np.complex128, distribution_strategy='fftw')
data_object.set_local_data(data=result)
return data_object.get_full_data()
elif fft_machine == 'gfft' or 'gfft_fallback':
......@@ -255,13 +279,13 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
None
"""
def transform(self,field_val,domain,codomain,**kwargs):
def transform(self, val, domain, codomain, **kwargs):
"""
The gfft transform function.
Parameters
----------
field_val : numpy.ndarray
val : numpy.ndarray
The value-array of the field which is supposed to
be transformed.
......@@ -286,7 +310,7 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
ftmachine = "ifft"
## transform and return
if(domain.datatype==np.float64):
return gfft.gfft(field_val.astype(np.complex128),in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
return gfft.gfft(val.astype(np.complex128),in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
else:
return gfft.gfft(field_val,in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
return gfft.gfft(val,in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
\ No newline at end of file
......@@ -42,6 +42,7 @@ from nifty.nifty_core import about, \
random, \
space, \
field
import nifty.nifty_mpi_data
import nifty.smoothing as gs
import powerspectrum as gp
'''
......@@ -204,7 +205,7 @@ class rg_space(space):
self.fourier = bool(fourier)
## Initializes the fast-fourier-transform machine, which will be used
## to transform the spaace
## to transform the space
self.fft_machine = fft_rg.fft_factory()
......@@ -823,11 +824,9 @@ class rg_space(space):
## of transformation is infered from the fourier attribute of the
## supplied space
if(codomain.fourier):
#ftmachine = "fft"
## correct for 'fft'
x = self.calc_weight(x,power=1)
else:
#ftmachine = "ifft"
## correct for 'ifft'
x = self.calc_weight(x,power=1)
x *= self.dim(split=False)
......@@ -837,7 +836,7 @@ class rg_space(space):
#ftmachine = "none"
## transform
Tx = self.fft_machine.transform(x,self,codomain)
Tx = self.fft_machine.transform(x,self,codomain,**kwargs)
## check complexity
if(not codomain.para[naxes]): ## purely real
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment