Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
83ae9e94
Commit
83ae9e94
authored
Jul 20, 2016
by
Jait Dixit
Browse files
Fix issues
#50
and
#52
parent
69fa111d
Changes
11
Show whitespace changes
Inline
Side-by-side
test/test_nifty_transforms.py
View file @
83ae9e94
...
...
@@ -7,7 +7,7 @@ import itertools
from
nifty
import
RGSpace
,
LMSpace
,
HPSpace
,
GLSpace
from
nifty
import
transformator
import
nifty.transformations.transformation
as
t
ransformation
from
nifty.transformations.
rgrg
transformation
import
RGRGT
ransformation
from
nifty.rg.rg_space
import
gc
as
RG_GC
import
d2o
...
...
@@ -72,8 +72,8 @@ class TestRGRGTransformation(unittest.TestCase):
def
test_check_codomain_rgspecific
(
self
,
complexity
,
distances
,
harmonic
):
x
=
RGSpace
((
8
,
8
),
complexity
=
complexity
,
distances
=
distances
,
harmonic
=
harmonic
)
assert
(
transformation
.
RGRGTransformation
.
check_codomain
(
x
,
x
.
get_codomain
()))
assert
(
transformation
.
RGRGTransformation
.
check_codomain
(
x
,
x
.
get_codomain
()))
assert
(
RGRGTransformation
.
check_codomain
(
x
,
x
.
get_codomain
()))
assert
(
RGRGTransformation
.
check_codomain
(
x
,
x
.
get_codomain
()))
@
parameterized
.
expand
(
rg_rg_fft_modules
,
testcase_func_name
=
custom_name_func
)
def
test_shapemismatch
(
self
,
module
):
...
...
transformations/gfft.py
deleted
100644 → 0
View file @
69fa111d
import
numpy
as
np
from
transform
import
Transform
from
d2o
import
distributed_data_object
import
nifty.nifty_utilities
as
utilities
class
GFFT
(
Transform
):
"""
The gfft pendant of a fft object.
Parameters
----------
fft_module_name : String
Switch between the gfft module used: 'gfft' and 'gfft_dummy'
"""
def
__init__
(
self
,
domain
,
codomain
,
fft_module
):
self
.
domain
=
domain
self
.
codomain
=
codomain
self
.
fft_machine
=
fft_module
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
"""
The gfft transform function.
Parameters
----------
val : numpy.ndarray or distributed_data_object
The value-array of the field which is supposed to
be transformed.
axes : None or tuple
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
Returns
-------
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
not
all
(
axis
in
range
(
len
(
val
.
shape
))
for
axis
in
axes
):
raise
ValueError
(
"ERROR: Provided axes does not match array shape"
)
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if
isinstance
(
val
,
distributed_data_object
):
temp_inp
=
val
.
get_full_data
()
else
:
temp_inp
=
val
# Array for storing the result
return_val
=
None
for
slice_list
in
utilities
.
get_slice_list
(
temp_inp
.
shape
,
axes
):
# don't copy the whole data array
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
temp_inp
else
:
# initialize the return_val object if needed
if
return_val
is
None
:
return_val
=
np
.
empty_like
(
temp_inp
)
inp
=
temp_inp
[
slice_list
]
inp
=
self
.
fft_machine
.
gfft
(
inp
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
'fft'
if
self
.
codomain
.
harmonic
else
'ifft'
,
in_zero_center
=
map
(
bool
,
self
.
domain
.
paradict
[
'zerocenter'
]
),
out_zero_center
=
map
(
bool
,
self
.
codomain
.
paradict
[
'zerocenter'
]
),
enforce_hermitian_symmetry
=
bool
(
self
.
codomain
.
paradict
[
'complexity'
]
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
if
slice_list
==
[
slice
(
None
,
None
)]:
return_val
=
inp
else
:
return_val
[
slice_list
]
=
inp
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if
self
.
domain
.
paradict
[
'complexity'
]
==
0
:
new_val
.
hermitian
=
True
return_val
=
new_val
else
:
return_val
=
return_val
.
astype
(
self
.
codomain
.
dtype
,
copy
=
False
)
return
return_val
transformations/gltransform.py
→
transformations/gl
lm
transform
ation
.py
View file @
83ae9e94
import
numpy
as
np
from
transform
import
Transform
from
transform
ation
import
Transform
ation
from
d2o
import
distributed_data_object
from
nifty.config
import
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
nifty
import
GLSpace
,
LMSpace
gl
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
class
GLTransform
(
Transform
):
"""
GLTransform wrapper for libsharp's transform functions
"""
class
GL
LM
Transform
ation
(
Transform
ation
):
def
__init__
(
self
,
domain
,
codomain
,
module
=
None
):
if
'libsharp_wrapper_gl'
not
in
gdi
:
raise
ImportError
(
"The module libsharp is needed but not available"
)
def
__init__
(
self
,
domain
,
codomain
):
if
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
if
'libsharp_wrapper_gl'
not
in
gdi
:
raise
ImportError
(
"The module libsharp_wrapper_gl "
+
"is needed but not available"
)
@
staticmethod
def
check_codomain
(
domain
,
codomain
):
if
not
isinstance
(
domain
,
GLSpace
):
raise
TypeError
(
'ERROR: domain is not a GLSpace'
)
if
codomain
is
None
:
return
False
if
not
isinstance
(
codomain
,
LMSpace
):
raise
TypeError
(
'ERROR: codomain must be a LMSpace.'
)
nlat
=
domain
.
paradict
[
'nlat'
]
nlon
=
domain
.
paradict
[
'nlon'
]
lmax
=
codomain
.
paradict
[
'lmax'
]
mmax
=
codomain
.
paradict
[
'mmax'
]
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
if
(
nlon
!=
2
*
nlat
-
1
)
or
(
lmax
!=
nlat
-
1
)
or
(
lmax
!=
mmax
):
return
False
return
True
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
"""
GL -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if
self
.
domain
.
discrete
:
val
=
self
.
domain
.
calc_weight
(
val
,
power
=-
0.5
)
...
...
transformations/hptransform.py
→
transformations/hp
lm
transform
ation
.py
View file @
83ae9e94
import
numpy
as
np
from
transform
import
Transform
from
transform
ation
import
Transform
ation
from
d2o
import
distributed_data_object
from
nifty.config
import
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
nifty
import
HPSpace
,
LMSpace
hp
=
gdi
.
get
(
'healpy'
)
class
HPTransform
(
Transform
):
"""
GLTransform wrapper for libsharp's transform functions
"""
def
__init__
(
self
,
domain
,
codomain
):
class
HPLMTransformation
(
Transformation
):
def
__init__
(
self
,
domain
,
codomain
,
module
=
None
):
if
'healpy'
not
in
gdi
:
raise
ImportError
(
"The module healpy is needed but not available"
)
if
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
if
'healpy'
not
in
gdi
:
raise
ImportError
(
"The module healpy is needed but not available"
)
@
staticmethod
def
check_codomain
(
domain
,
codomain
):
if
not
isinstance
(
domain
,
HPSpace
):
raise
TypeError
(
'ERROR: domain is not a HPSpace'
)
if
codomain
is
None
:
return
False
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
if
not
isinstance
(
codomain
,
LMSpace
):
raise
TypeError
(
'ERROR: codomain must be a LMSpace.'
)
nside
=
domain
.
paradict
[
'nside'
]
lmax
=
codomain
.
paradict
[
'lmax'
]
mmax
=
codomain
.
paradict
[
'mmax'
]
if
(
3
*
nside
-
1
!=
lmax
)
or
(
lmax
!=
mmax
):
return
False
return
True
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
"""
HP -> LM transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
# get by number of iterations from kwargs
niter
=
kwargs
[
'niter'
]
if
'niter'
in
kwargs
else
0
...
...
transformations/lmtransform.py
→
transformations/lm
gl
transform
ation
.py
View file @
83ae9e94
import
numpy
as
np
from
nifty
import
GLSpace
,
HPSpace
from
nifty.config
import
about
import
nifty.nifty_utilities
as
utilities
from
transform
import
Transform
from
transformation
import
Transformation
from
d2o
import
distributed_data_object
from
nifty.config
import
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
nifty
import
GLSpace
,
LMSpace
gl
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
class
LMTransform
(
Transform
):
"""
LMTransform for transforming to GL/HP space
"""
def
__init__
(
self
,
domain
,
codomain
,
module
):
class
LMGLTransformation
(
Transformation
):
def
__init__
(
self
,
domain
,
codomain
,
module
=
None
):
if
gdi
.
get
(
'libsharp_wrapper_gl'
)
is
None
:
raise
ImportError
(
"The module libsharp is needed but not available."
)
if
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
self
.
module
=
module
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
def
_transform
(
self
,
val
):
if
isinstance
(
self
.
codomain
,
GLSpace
):
# shorthand for transform parameters
nlat
=
self
.
codomain
.
paradict
[
'nlat'
]
nlon
=
self
.
codomain
.
paradict
[
'nlon'
]
lmax
=
self
.
domain
.
paradict
[
'lmax'
]
mmax
=
self
.
paradict
[
'mmax'
]
@
staticmethod
def
check_codomain
(
domain
,
codomain
):
if
not
isinstance
(
domain
,
LMSpace
):
raise
TypeError
(
'ERROR: domain is not a LMSpace'
)
if
self
.
domain
.
dtype
==
np
.
dtype
(
'complex64'
):
val
=
self
.
module
.
alm2map_f
(
val
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
,
cl
=
False
)
else
:
val
=
self
.
module
.
alm2map
(
val
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
,
cl
=
False
)
elif
isinstance
(
self
.
codomain
,
HPSpace
):
# shorthand for transform parameters
nside
=
self
.
codomain
.
paradict
[
'nside'
]
lmax
=
self
.
domain
.
paradict
[
'lmax'
]
mmax
=
self
.
domain
.
paradict
[
'mmax'
]
if
codomain
is
None
:
return
False
val
=
val
.
astype
(
np
.
complex128
,
copy
=
False
)
val
=
self
.
module
.
alm2map
(
val
,
nside
,
lmax
=
lmax
,
mmax
=
mmax
,
pixwin
=
False
,
fwhm
=
0.0
,
sigma
=
None
,
pol
=
True
,
inplace
=
False
)
else
:
raise
ValueError
(
"ERROR: Unsupported transformation."
)
if
not
isinstance
(
codomain
,
GLSpace
):
raise
TypeError
(
'ERROR: codomain must be a GLSpace.'
)
return
val
nlat
=
codomain
.
paradict
[
'nlat'
]
nlon
=
codomain
.
paradict
[
'nlon'
]
lmax
=
domain
.
paradict
[
'lmax'
]
mmax
=
domain
.
paradict
[
'mmax'
]
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
if
(
lmax
!=
mmax
)
or
(
nlat
!=
lmax
+
1
)
or
(
nlon
!=
2
*
lmax
+
1
):
return
False
return
True
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
"""
LM -> GL transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if
isinstance
(
val
,
distributed_data_object
):
temp_val
=
val
.
get_full_data
()
else
:
...
...
@@ -60,7 +69,17 @@ class LMTransform(Transform):
return_val
=
np
.
empty_like
(
temp_val
)
inp
=
temp_val
[
slice_list
]
inp
=
self
.
_transform
(
inp
)
nlat
=
self
.
codomain
.
paradict
[
'nlat'
]
nlon
=
self
.
codomain
.
paradict
[
'nlon'
]
lmax
=
self
.
domain
.
paradict
[
'lmax'
]
mmax
=
self
.
paradict
[
'mmax'
]
if
self
.
domain
.
dtype
==
np
.
dtype
(
'complex64'
):
inp
=
gl
.
alm2map_f
(
inp
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
,
cl
=
False
)
else
:
inp
=
gl
.
alm2map
(
inp
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
,
cl
=
False
)
if
slice_list
==
[
slice
(
None
,
None
)]:
return_val
=
inp
...
...
transformations/lmhptransformation.py
0 → 100644
View file @
83ae9e94
import
numpy
as
np
from
transformation
import
Transformation
from
d2o
import
distributed_data_object
from
nifty.config
import
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
nifty
import
HPSpace
,
LMSpace
hp
=
gdi
.
get
(
'healpy'
)
class
LMHPTransformation
(
Transformation
):
def
__init__
(
self
,
domain
,
codomain
,
module
=
None
):
if
gdi
.
get
(
'healpy'
)
is
None
:
raise
ImportError
(
"The module libsharp is needed but not available."
)
if
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
@
staticmethod
def
check_codomain
(
domain
,
codomain
):
if
not
isinstance
(
domain
,
LMSpace
):
raise
TypeError
(
'ERROR: domain is not a LMSpace'
)
if
codomain
is
None
:
return
False
if
not
isinstance
(
codomain
,
HPSpace
):
raise
TypeError
(
'ERROR: codomain must be a HPSpace.'
)
nside
=
codomain
.
paradict
[
'nside'
]
lmax
=
domain
.
paradict
[
'lmax'
]
mmax
=
domain
.
paradict
[
'mmax'
]
if
(
lmax
!=
mmax
)
or
(
3
*
nside
-
1
!=
lmax
):
return
False
return
True
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
"""
LM -> HP transform method.
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
axes : None or tuple
The axes along which the transformation should take place
"""
if
isinstance
(
val
,
distributed_data_object
):
temp_val
=
val
.
get_full_data
()
else
:
temp_val
=
val
return_val
=
None
for
slice_list
in
utilities
.
get_slice_list
(
temp_val
.
shape
,
axes
):
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
temp_val
else
:
if
return_val
is
None
:
return_val
=
np
.
empty_like
(
temp_val
)
inp
=
temp_val
[
slice_list
]
nside
=
self
.
codomain
.
paradict
[
'nside'
]
lmax
=
self
.
domain
.
paradict
[
'lmax'
]
mmax
=
self
.
domain
.
paradict
[
'mmax'
]
inp
=
inp
.
astype
(
np
.
complex128
,
copy
=
False
)
inp
=
hp
.
alm2map
(
inp
,
nside
,
lmax
=
lmax
,
mmax
=
mmax
,
pixwin
=
False
,
fwhm
=
0.0
,
sigma
=
None
,
pol
=
True
,
inplace
=
False
)
if
slice_list
==
[
slice
(
None
,
None
)]:
return_val
=
inp
else
:
return_val
[
slice_list
]
=
inp
# re-weight if discrete
if
self
.
codomain
.
discrete
:
val
=
self
.
codomain
.
calc_weight
(
val
,
power
=
0.5
)
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
else
:
return_val
=
return_val
.
astype
(
self
.
codomain
.
dtype
,
copy
=
False
)
return
return_val
transformations/
fftw
.py
→
transformations/
rg_transforms
.py
View file @
83ae9e94
...
...
@@ -4,15 +4,39 @@ import numpy as np
from
d2o
import
distributed_data_object
,
STRATEGIES
from
nifty.config
import
about
,
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
transform
import
Transform
from
mpi4py
import
MPI
from
nifty
import
nifty_configuration
pyfftw
=
gdi
.
get
(
'pyfftw'
)
class
FFTW
(
Transform
):
class
Transform
(
object
):
"""
A generic fft object without any implementation.
"""
def
__init__
(
self
,
domain
,
codomain
):
pass
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise
NotImplementedError
class
FFTW
(
Transform
):
"""
The pyfftw pendant of a fft object.
"""
...
...
@@ -106,7 +130,8 @@ class FFTW(Transform):
args
+
offset
.
reshape
(
offset
.
shape
+
(
1
,)
*
(
np
.
array
(
args
).
ndim
-
1
)),
(
np
.
array
(
args
).
ndim
-
1
)),
1
)),
(
2
,)
*
to_center
.
size
)
# Cast the core to the smallest integers we can get
...
...
@@ -193,7 +218,7 @@ class FFTW(Transform):
def
_atomic_mpi_transform
(
self
,
val
,
info
,
axes
):
# Apply codomain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
codomain
.
paradict
[
'zerocenter'
]):
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
codomain
.
paradict
[
'zerocenter'
]):
temp_val
=
np
.
copy
(
val
)
val
=
self
.
_apply_mask
(
temp_val
,
info
.
cmask_codomain
,
axes
)
...
...
@@ -210,7 +235,7 @@ class FFTW(Transform):
return
None
# Apply domain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain
.
paradict
[
'zerocenter'
]):
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain
.
paradict
[
'zerocenter'
]):
result
=
self
.
_apply_mask
(
result
,
info
.
cmask_domain
,
axes
)
# Correct the sign if needed
...
...
@@ -238,7 +263,7 @@ class FFTW(Transform):
**
kwargs
)
# Apply codomain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
codomain
.
paradict
[
'zerocenter'
]):
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
codomain
.
paradict
[
'zerocenter'
]):
temp_val
=
np
.
copy
(
local_val
)
local_val
=
self
.
_apply_mask
(
temp_val
,
current_info
.
cmask_codomain
,
axes
)
...
...
@@ -250,7 +275,7 @@ class FFTW(Transform):
)
# Apply domain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain
.
paradict
[
'zerocenter'
]):
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain
.
paradict
[
'zerocenter'
]):
local_result
=
self
.
_apply_mask
(
local_result
,
current_info
.
cmask_domain
,
axes
)
...
...
@@ -297,7 +322,6 @@ class FFTW(Transform):
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
self
.
codomain
.
dtype
)
# Extract local data
local_val
=
val
.
get_local_data
(
copy
=
False
)
...
...
@@ -335,7 +359,8 @@ class FFTW(Transform):
local_shape
=
val
.
local_shape
,
local_offset_Q
=
local_offset_Q
,
is_local
=
False
,
transform_shape
=
val
.
shape
,
# TODO: check why inp.shape doesn't work
transform_shape
=
val
.
shape
,
# TODO: check why inp.shape doesn't work
**
kwargs
)
...
...
@@ -420,7 +445,6 @@ class FFTW(Transform):
class
FFTWTransformInfo
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
if
pyfftw
is
None
:
...
...
@@ -468,7 +492,6 @@ class FFTWTransformInfo(object):
class
FFTWLocalTransformInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
super
(
FFTWLocalTransformInfo
,
self
).
__init__
(
domain
,
...
...
@@ -493,7 +516,6 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class
FFTWMPITransfromInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
transform_shape
,
**
kwargs
):
super
(
FFTWMPITransfromInfo
,
self
).
__init__
(
domain
,
...
...
@@ -519,3 +541,107 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
def
plan
(
self
,
plan
):
about
.
warnings
.
cprint
(
'WARNING: FFTWMPITransfromInfo plan
\
cannot be modified'
)
class
GFFT
(
Transform
):
"""
The gfft pendant of a fft object.
Parameters