Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
0f24d97e
Commit
0f24d97e
authored
Mar 23, 2015
by
ultimanet
Browse files
added distributed_data_object
parent
47ecedc4
Changes
4
Hide whitespace changes
Inline
Side-by-side
__init__.py
View file @
0f24d97e
...
...
@@ -25,6 +25,7 @@ from nifty_cmaps import *
from
nifty_power
import
*
from
nifty_tools
import
*
from
nifty_explicit
import
*
from
nifty_mpi_data
import
distributed_data_object
## optional submodule `rg`
try
:
...
...
@@ -42,3 +43,4 @@ except(ImportError):
from
demos
import
*
from
pickling
import
*
#import pyximport; pyximport.install(pyimport = True)
\ No newline at end of file
nifty_core.py
View file @
0f24d97e
...
...
@@ -148,7 +148,7 @@ import pylab as pl
from
multiprocessing
import
Pool
as
mp
from
multiprocessing
import
Value
as
mv
from
multiprocessing
import
Array
as
ma
from
nifty_mpi_data
import
distributed_data_object
__version__
=
"1.0.6"
...
...
@@ -4983,12 +4983,24 @@ class field(object):
else
:
self
.
domain
.
check_codomain
(
target
)
self
.
target
=
target
self
.
distributed_val
=
distributed_data_object
(
global_shape
=
domain
.
dim
(
split
=
True
),
dtype
=
domain
.
datatype
)
## check values
if
(
val
is
None
):
self
.
val
=
self
.
domain
.
get_random_values
(
codomain
=
self
.
target
,
**
kwargs
)
else
:
self
.
val
=
self
.
domain
.
enforce_values
(
val
,
extend
=
True
)
@
property
def
val
(
self
):
return
self
.
distributed_val
.
get_full_data
()
#return self.distributed_val
@
val
.
setter
def
val
(
self
,
x
):
return
self
.
distributed_val
.
set_full_data
(
x
)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def
dim
(
self
,
split
=
False
):
...
...
@@ -5357,11 +5369,11 @@ class field(object):
else
:
self
.
domain
.
check_codomain
(
target
)
## a bit pointless
if
(
overwrite
):
self
.
val
=
self
.
domain
.
calc_transform
(
self
.
val
,
codomain
=
target
,
**
kwargs
)
self
.
val
=
self
.
domain
.
calc_transform
(
self
.
val
,
codomain
=
target
,
field_val
=
self
.
distributed_val
,
**
kwargs
)
self
.
target
=
self
.
domain
self
.
domain
=
target
else
:
return
field
(
target
,
val
=
self
.
domain
.
calc_transform
(
self
.
val
,
codomain
=
target
,
**
kwargs
),
target
=
self
.
domain
)
return
field
(
target
,
val
=
self
.
domain
.
calc_transform
(
self
.
val
,
codomain
=
target
,
field_val
=
self
.
distributed_val
,
**
kwargs
),
target
=
self
.
domain
)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
...
...
rg/fft_rg.py
View file @
0f24d97e
# -*- coding: utf-8 -*-
import
numpy
as
np
from
nifty
import
nifty_mpi_data
# Try to import pyfftw. If this fails fall back to gfft. If this fails fall back to local gfft_rg
...
...
@@ -54,7 +55,7 @@ class fft(object):
Parameters
----------
field_val :
numpy.ndarray
field_val :
distributed_data_object
The value-array of the field which is supposed to
be transformed.
...
...
@@ -140,8 +141,14 @@ if fft_machine == 'pyfftw':
None
"""
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict
=
{}
def
__init__
(
self
):
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
self
.
plan_dict
=
{}
## initialize the dictionary which stores the values from get_centering_mask
self
.
centering_mask_dict
=
{}
def
get_centering_mask
(
self
,
to_center_input
,
dimensions_input
,
offset_input
=
0
):
"""
Computes the mask, used to (de-)zerocenter domain and target
...
...
@@ -197,9 +204,7 @@ if fft_machine == 'pyfftw':
self
.
centering_mask_dict
[
temp_id
]
=
centering_mask
return
self
.
centering_mask_dict
[
temp_id
]
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
plan_dict
=
{}
def
_get_plan_and_info
(
self
,
domain
,
codomain
,
**
kwargs
):
## generate a id-tuple which identifies the domain-codomain setting
temp_id
=
(
domain
.
__identifier__
(),
codomain
.
__identifier__
())
...
...
@@ -208,13 +213,13 @@ if fft_machine == 'pyfftw':
self
.
plan_dict
[
temp_id
]
=
_fftw_plan_and_info
(
domain
,
codomain
,
self
,
**
kwargs
)
return
self
.
plan_dict
[
temp_id
]
def
transform
(
self
,
field_
val
,
domain
,
codomain
,
**
kwargs
):
def
transform
(
self
,
val
,
domain
,
codomain
,
field_val
,
**
kwargs
):
"""
The pyfftw transform function.
Parameters
----------
field_val :
numpy.ndarray
field_val :
distributed_data_object
The value-array of the field which is supposed to
be transformed.
...
...
@@ -234,15 +239,34 @@ if fft_machine == 'pyfftw':
"""
current_plan_and_info
=
self
.
_get_plan_and_info
(
domain
,
codomain
,
**
kwargs
)
## Prepare the input data
field_val
*=
current_plan_and_info
.
get_codomain_centering_mask
()
local_size
=
current_plan_and_info
.
fftw_local_size
local_start
=
local_size
[
2
]
local_end
=
local_start
+
local_size
[
1
]
val
=
field_val
.
get_data
(
slice
(
local_start
,
local_end
))
val
*=
current_plan_and_info
.
get_codomain_centering_mask
()
## Define a abbreviation for the fftw plan
p
=
current_plan_and_info
.
get_plan
()
## load the field into the plan
if
p
.
has_input
:
p
.
input_array
[:]
=
field_
val
p
.
input_array
[:]
=
val
## execute the plan
p
()
return
p
.
output_array
*
current_plan_and_info
.
get_domain_centering_mask
()
result
=
p
.
output_array
*
current_plan_and_info
.
get_domain_centering_mask
()
## renorm the result according to the convention of gfft
if
current_plan_and_info
.
direction
==
'FFTW_FORWARD'
:
result
=
result
/
float
(
result
.
size
)
else
:
result
*=
float
(
result
.
size
)
## build a distributed_data_object
data_object
=
nifty_mpi_data
.
distributed_data_object
(
global_shape
=
current_plan_and_info
.
global_output_shape
,
dtype
=
np
.
complex128
,
distribution_strategy
=
'fftw'
)
data_object
.
set_local_data
(
data
=
result
)
return
data_object
.
get_full_data
()
elif
fft_machine
==
'gfft'
or
'gfft_fallback'
:
...
...
@@ -255,13 +279,13 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
None
"""
def
transform
(
self
,
field_
val
,
domain
,
codomain
,
**
kwargs
):
def
transform
(
self
,
val
,
domain
,
codomain
,
**
kwargs
):
"""
The gfft transform function.
Parameters
----------
field_
val : numpy.ndarray
val : numpy.ndarray
The value-array of the field which is supposed to
be transformed.
...
...
@@ -286,7 +310,7 @@ elif fft_machine == 'gfft' or 'gfft_fallback':
ftmachine
=
"ifft"
## transform and return
if
(
domain
.
datatype
==
np
.
float64
):
return
gfft
.
gfft
(
field_
val
.
astype
(
np
.
complex128
),
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
return
gfft
.
gfft
(
val
.
astype
(
np
.
complex128
),
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
else
:
return
gfft
.
gfft
(
field_
val
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
return
gfft
.
gfft
(
val
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
\ No newline at end of file
rg/nifty_rg.py
View file @
0f24d97e
...
...
@@ -42,6 +42,7 @@ from nifty.nifty_core import about, \
random
,
\
space
,
\
field
import
nifty.nifty_mpi_data
import
nifty.smoothing
as
gs
import
powerspectrum
as
gp
'''
...
...
@@ -204,7 +205,7 @@ class rg_space(space):
self
.
fourier
=
bool
(
fourier
)
## Initializes the fast-fourier-transform machine, which will be used
## to transform the spa
a
ce
## to transform the space
self
.
fft_machine
=
fft_rg
.
fft_factory
()
...
...
@@ -823,11 +824,9 @@ class rg_space(space):
## of transformation is infered from the fourier attribute of the
## supplied space
if
(
codomain
.
fourier
):
#ftmachine = "fft"
## correct for 'fft'
x
=
self
.
calc_weight
(
x
,
power
=
1
)
else
:
#ftmachine = "ifft"
## correct for 'ifft'
x
=
self
.
calc_weight
(
x
,
power
=
1
)
x
*=
self
.
dim
(
split
=
False
)
...
...
@@ -837,7 +836,7 @@ class rg_space(space):
#ftmachine = "none"
## transform
Tx
=
self
.
fft_machine
.
transform
(
x
,
self
,
codomain
)
Tx
=
self
.
fft_machine
.
transform
(
x
,
self
,
codomain
,
**
kwargs
)
## check complexity
if
(
not
codomain
.
para
[
naxes
]):
## purely real
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment