Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
64a27773
Commit
64a27773
authored
May 01, 2017
by
Theo Steininger
Browse files
Small consistency fixes for spaces and FFTOperators.
parent
c6c1bbb9
Pipeline
#11884
passed with stage
in 10 minutes and 16 seconds
Changes
17
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/domain_object.py
View file @
64a27773
...
...
@@ -18,8 +18,6 @@
import
abc
import
numpy
as
np
from
keepers
import
Loggable
,
\
Versionable
...
...
nifty/field.py
View file @
64a27773
...
...
@@ -45,8 +45,7 @@ class Field(Loggable, Versionable, object):
self
.
domain_axes
=
self
.
_get_axes_tuple
(
self
.
domain
)
self
.
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
,
domain
=
self
.
domain
)
val
=
val
)
self
.
distribution_strategy
=
self
.
_parse_distribution_strategy
(
distribution_strategy
=
distribution_strategy
,
...
...
@@ -86,16 +85,17 @@ class Field(Loggable, Versionable, object):
axes_list
+=
[
tuple
(
l
)]
return
tuple
(
axes_list
)
def
_infer_dtype
(
self
,
dtype
,
val
,
domain
):
def
_infer_dtype
(
self
,
dtype
,
val
):
if
dtype
is
None
:
if
isinstance
(
val
,
Field
)
or
\
isinstance
(
val
,
distributed_data_object
):
try
:
dtype
=
val
.
dtype
dtype_tuple
=
(
np
.
dtype
(
gc
[
'default_field_dtype'
]),)
except
AttributeError
:
if
val
is
not
None
:
dtype
=
np
.
result_type
(
val
)
else
:
dtype
=
np
.
dtype
(
gc
[
'default_field_dtype'
])
else
:
dtype_tuple
=
(
np
.
dtype
(
dtype
),)
dtype
=
reduce
(
lambda
x
,
y
:
np
.
result_type
(
x
,
y
),
dtype_tuple
)
dtype
=
np
.
dtype
(
dtype
)
return
dtype
...
...
nifty/operators/fft_operator/fft_operator.py
View file @
64a27773
...
...
@@ -83,12 +83,14 @@ class FFTOperator(LinearOperator):
# Store the dtype information
if
domain_dtype
is
None
:
self
.
domain_dtype
=
None
self
.
logger
.
info
(
"Setting domain_dtype to np.float."
)
self
.
domain_dtype
=
np
.
float
else
:
self
.
domain_dtype
=
np
.
dtype
(
domain_dtype
)
if
target_dtype
is
None
:
self
.
target_dtype
=
None
self
.
logger
.
info
(
"Setting target_dtype to np.complex."
)
self
.
target_dtype
=
np
.
complex
else
:
self
.
target_dtype
=
np
.
dtype
(
target_dtype
)
...
...
nifty/operators/fft_operator/transformations/gllmtransformation.py
View file @
64a27773
...
...
@@ -21,7 +21,7 @@ import numpy as np
from
nifty.config
import
dependency_injector
as
gdi
from
nifty
import
GLSpace
,
LMSpace
from
slicing_transformation
import
SlicingTransformation
import
lm_transformation_
factory
import
lm_transformation_
helper
pyHealpix
=
gdi
.
get
(
'pyHealpix'
)
...
...
@@ -31,12 +31,17 @@ class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
module
is
None
:
module
=
'pyHealpix'
if
module
!=
'pyHealpix'
:
raise
ValueError
(
"Unsupported SHT module."
)
if
'pyHealpix'
not
in
gdi
:
raise
ImportError
(
"The module pyHealpix is needed but not available."
)
super
(
GLLMTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
=
module
)
super
(
GLLMTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
# ---Mandatory properties and methods---
...
...
@@ -63,7 +68,7 @@ class GLLMTransformation(SlicingTransformation):
nlat
=
domain
.
nlat
lmax
=
nlat
-
1
result
=
LMSpace
(
lmax
=
lmax
,
dtype
=
domain
.
dtype
)
result
=
LMSpace
(
lmax
=
lmax
)
return
result
@
classmethod
...
...
@@ -91,6 +96,11 @@ class GLLMTransformation(SlicingTransformation):
super
(
GLLMTransformation
,
cls
).
check_codomain
(
domain
,
codomain
)
def
_transformation_of_slice
(
self
,
inp
,
**
kwargs
):
if
inp
.
dtype
not
in
(
np
.
float
,
np
.
complex
):
self
.
logger
.
warn
(
"The input array has dtype: %s. The FFT will "
"be performed at double precision."
%
str
(
inp
.
dtype
))
nlat
=
self
.
domain
.
nlat
nlon
=
self
.
domain
.
nlon
lmax
=
self
.
codomain
.
lmax
...
...
@@ -104,12 +114,12 @@ class GLLMTransformation(SlicingTransformation):
for
x
in
(
inp
.
real
,
inp
.
imag
)]
[
resultReal
,
resultImag
]
=
[
lm_transformation_
factory
.
buildIdx
(
x
,
lmax
=
lmax
)
resultImag
]
=
[
lm_transformation_
helper
.
buildIdx
(
x
,
lmax
=
lmax
)
for
x
in
[
resultReal
,
resultImag
]]
result
=
self
.
_combine_complex_result
(
resultReal
,
resultImag
)
else
:
result
=
sjob
.
map2alm
(
inp
)
result
=
lm_transformation_
factory
.
buildIdx
(
result
,
lmax
=
lmax
)
result
=
lm_transformation_
helper
.
buildIdx
(
result
,
lmax
=
lmax
)
return
result
nifty/operators/fft_operator/transformations/hplmtransformation.py
View file @
64a27773
...
...
@@ -22,7 +22,7 @@ from nifty.config import dependency_injector as gdi
from
nifty
import
HPSpace
,
LMSpace
from
slicing_transformation
import
SlicingTransformation
import
lm_transformation_
factory
import
lm_transformation_
helper
pyHealpix
=
gdi
.
get
(
'pyHealpix'
)
...
...
@@ -32,12 +32,17 @@ class HPLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
module
is
None
:
module
=
'pyHealpix'
if
module
!=
'pyHealpix'
:
raise
ValueError
(
"Unsupported SHT module."
)
if
'pyHealpix'
not
in
gdi
:
raise
ImportError
(
"The module pyHealpix is needed but not available"
)
super
(
HPLMTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
=
module
)
super
(
HPLMTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
# ---Mandatory properties and methods---
...
...
@@ -63,7 +68,7 @@ class HPLMTransformation(SlicingTransformation):
lmax
=
2
*
domain
.
nside
result
=
LMSpace
(
lmax
=
lmax
,
dtype
=
domain
.
dtype
)
result
=
LMSpace
(
lmax
=
lmax
)
return
result
@
classmethod
...
...
@@ -83,6 +88,11 @@ class HPLMTransformation(SlicingTransformation):
super
(
HPLMTransformation
,
cls
).
check_codomain
(
domain
,
codomain
)
def
_transformation_of_slice
(
self
,
inp
,
**
kwargs
):
if
inp
.
dtype
not
in
(
np
.
float
,
np
.
complex
):
self
.
logger
.
warn
(
"The input array has dtype: %s. The FFT will "
"be performed at double precision."
%
str
(
inp
.
dtype
))
lmax
=
self
.
codomain
.
lmax
mmax
=
lmax
...
...
@@ -92,13 +102,13 @@ class HPLMTransformation(SlicingTransformation):
for
x
in
(
inp
.
real
,
inp
.
imag
)]
[
resultReal
,
resultImag
]
=
[
lm_transformation_
factory
.
buildIdx
(
x
,
lmax
=
lmax
)
resultImag
]
=
[
lm_transformation_
helper
.
buildIdx
(
x
,
lmax
=
lmax
)
for
x
in
[
resultReal
,
resultImag
]]
result
=
self
.
_combine_complex_result
(
resultReal
,
resultImag
)
else
:
result
=
pyHealpix
.
map2alm_iter
(
inp
,
lmax
,
mmax
,
3
)
result
=
lm_transformation_
factory
.
buildIdx
(
result
,
lmax
=
lmax
)
result
=
lm_transformation_
helper
.
buildIdx
(
result
,
lmax
=
lmax
)
return
result
nifty/operators/fft_operator/transformations/lm_transformation_
factory
.py
→
nifty/operators/fft_operator/transformations/lm_transformation_
helper
.py
View file @
64a27773
File moved
nifty/operators/fft_operator/transformations/lmgltransformation.py
View file @
64a27773
...
...
@@ -21,7 +21,7 @@ from nifty.config import dependency_injector as gdi
from
nifty
import
GLSpace
,
LMSpace
from
slicing_transformation
import
SlicingTransformation
import
lm_transformation_
factory
import
lm_transformation_
helper
pyHealpix
=
gdi
.
get
(
'pyHealpix'
)
...
...
@@ -31,12 +31,17 @@ class LMGLTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
module
is
None
:
module
=
'pyHealpix'
if
module
!=
'pyHealpix'
:
raise
ValueError
(
"Unsupported SHT module."
)
if
'pyHealpix'
not
in
gdi
:
raise
ImportError
(
"The module pyHealpix is needed but not available."
)
super
(
LMGLTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
=
module
)
super
(
LMGLTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
# ---Mandatory properties and methods---
...
...
@@ -97,6 +102,11 @@ class LMGLTransformation(SlicingTransformation):
super
(
LMGLTransformation
,
cls
).
check_codomain
(
domain
,
codomain
)
def
_transformation_of_slice
(
self
,
inp
,
**
kwargs
):
if
inp
.
dtype
not
in
(
np
.
float
,
np
.
complex
):
self
.
logger
.
warn
(
"The input array has dtype: %s. The FFT will "
"be performed at double precision."
%
str
(
inp
.
dtype
))
nlat
=
self
.
codomain
.
nlat
nlon
=
self
.
codomain
.
nlon
lmax
=
self
.
domain
.
lmax
...
...
@@ -107,7 +117,7 @@ class LMGLTransformation(SlicingTransformation):
sjob
.
set_triangular_alm_info
(
lmax
,
mmax
)
if
issubclass
(
inp
.
dtype
.
type
,
np
.
complexfloating
):
[
resultReal
,
resultImag
]
=
[
lm_transformation_
factory
.
buildLm
(
x
,
lmax
=
lmax
)
resultImag
]
=
[
lm_transformation_
helper
.
buildLm
(
x
,
lmax
=
lmax
)
for
x
in
(
inp
.
real
,
inp
.
imag
)]
[
resultReal
,
resultImag
]
=
[
sjob
.
alm2map
(
x
)
...
...
@@ -116,7 +126,7 @@ class LMGLTransformation(SlicingTransformation):
result
=
self
.
_combine_complex_result
(
resultReal
,
resultImag
)
else
:
result
=
lm_transformation_
factory
.
buildLm
(
inp
,
lmax
=
lmax
)
result
=
lm_transformation_
helper
.
buildLm
(
inp
,
lmax
=
lmax
)
result
=
sjob
.
alm2map
(
result
)
return
result
nifty/operators/fft_operator/transformations/lmhptransformation.py
View file @
64a27773
...
...
@@ -20,7 +20,7 @@ import numpy as np
from
nifty.config
import
dependency_injector
as
gdi
from
nifty
import
HPSpace
,
LMSpace
from
slicing_transformation
import
SlicingTransformation
import
lm_transformation_
factory
import
lm_transformation_
helper
pyHealpix
=
gdi
.
get
(
'pyHealpix'
)
...
...
@@ -30,12 +30,17 @@ class LMHPTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
module
is
None
:
module
=
'pyHealpix'
if
module
!=
'pyHealpix'
:
raise
ValueError
(
"Unsupported SHT module."
)
if
gdi
.
get
(
'pyHealpix'
)
is
None
:
raise
ImportError
(
"The module pyHealpix is needed but not available."
)
super
(
LMHPTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
=
module
)
super
(
LMHPTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
# ---Mandatory properties and methods---
...
...
@@ -85,13 +90,18 @@ class LMHPTransformation(SlicingTransformation):
super
(
LMHPTransformation
,
cls
).
check_codomain
(
domain
,
codomain
)
def
_transformation_of_slice
(
self
,
inp
,
**
kwargs
):
if
inp
.
dtype
not
in
(
np
.
float
,
np
.
complex
):
self
.
logger
.
warn
(
"The input array has dtype: %s. The FFT will "
"be performed at double precision."
%
str
(
inp
.
dtype
))
nside
=
self
.
codomain
.
nside
lmax
=
self
.
domain
.
lmax
mmax
=
lmax
if
issubclass
(
inp
.
dtype
.
type
,
np
.
complexfloating
):
[
resultReal
,
resultImag
]
=
[
lm_transformation_
factory
.
buildLm
(
x
,
lmax
=
lmax
)
resultImag
]
=
[
lm_transformation_
helper
.
buildLm
(
x
,
lmax
=
lmax
)
for
x
in
(
inp
.
real
,
inp
.
imag
)]
[
resultReal
,
resultImag
]
=
[
pyHealpix
.
alm2map
(
x
,
lmax
,
mmax
,
nside
)
...
...
@@ -100,7 +110,7 @@ class LMHPTransformation(SlicingTransformation):
result
=
self
.
_combine_complex_result
(
resultReal
,
resultImag
)
else
:
result
=
lm_transformation_
factory
.
buildLm
(
inp
,
lmax
=
lmax
)
result
=
lm_transformation_
helper
.
buildLm
(
inp
,
lmax
=
lmax
)
result
=
pyHealpix
.
alm2map
(
result
,
lmax
,
mmax
,
nside
)
return
result
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
64a27773
...
...
@@ -312,9 +312,8 @@ class FFTW(Transform):
try
:
# Create return object and insert results inplace
result_dtype
=
np
.
complex
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
result_dtype
)
dtype
=
np
.
complex
)
return_val
.
set_local_data
(
data
=
local_result
,
copy
=
False
)
except
(
AttributeError
):
return_val
=
local_result
...
...
@@ -341,9 +340,8 @@ class FFTW(Transform):
np
.
concatenate
([[
0
,
],
val
.
distributor
.
all_local_slices
[:,
2
]])
)
local_offset_Q
=
bool
(
local_offset_list
[
val
.
distributor
.
comm
.
rank
]
%
2
)
result_dtype
=
np
.
complex
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
result_dtype
)
dtype
=
np
.
complex
)
# Extract local data
local_val
=
val
.
get_local_data
(
copy
=
False
)
...
...
@@ -364,7 +362,7 @@ class FFTW(Transform):
if
temp_val
is
None
:
temp_val
=
np
.
empty_like
(
local_val
,
dtype
=
result_dtype
dtype
=
np
.
complex
)
inp
=
local_val
[
slice_list
]
...
...
@@ -439,6 +437,11 @@ class FFTW(Transform):
not
all
(
axis
in
range
(
len
(
val
.
shape
))
for
axis
in
axes
):
raise
ValueError
(
"Provided axes does not match array shape"
)
if
val
.
dtype
not
in
(
np
.
float
,
np
.
complex
):
self
.
logger
.
warn
(
"The input array has dtype: %s. The FFT will "
"be performed at double precision."
%
str
(
val
.
dtype
))
# If the input is a numpy array we transform it locally
if
not
isinstance
(
val
,
distributed_data_object
):
# Cast to a np.ndarray
...
...
@@ -583,9 +586,13 @@ class NUMPYFFT(Transform):
not
all
(
axis
in
range
(
len
(
val
.
shape
))
for
axis
in
axes
):
raise
ValueError
(
"Provided axes does not match array shape"
)
result_dtype
=
np
.
complex
if
val
.
dtype
not
in
(
np
.
float
,
np
.
complex
):
self
.
logger
.
warn
(
"The input array has dtype: %s. The FFT will "
"be performed at double precision."
%
str
(
val
.
dtype
))
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
result_dtype
)
dtype
=
np
.
complex
)
if
(
axes
is
None
)
or
(
0
in
axes
)
or
\
(
val
.
distribution_strategy
not
in
STRATEGIES
[
'slicing'
]):
...
...
nifty/operators/fft_operator/transformations/rgrgtransformation.py
View file @
64a27773
...
...
@@ -24,8 +24,7 @@ from nifty import RGSpace, nifty_configuration
class
RGRGTransformation
(
Transformation
):
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
super
(
RGRGTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
=
module
)
super
(
RGRGTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
if
module
is
None
:
if
nifty_configuration
[
'fft_module'
]
==
'fftw'
:
...
...
@@ -33,7 +32,7 @@ class RGRGTransformation(Transformation):
elif
nifty_configuration
[
'fft_module'
]
==
'numpy'
:
self
.
_transform
=
NUMPYFFT
(
self
.
domain
,
self
.
codomain
)
else
:
raise
ValueError
(
'
ERROR: unknow
default FFT module:'
+
raise
ValueError
(
'
Unsupported
default FFT module:'
+
nifty_configuration
[
'fft_module'
])
else
:
if
module
==
'fftw'
:
...
...
@@ -41,7 +40,7 @@ class RGRGTransformation(Transformation):
elif
module
==
'numpy'
:
self
.
_transform
=
NUMPYFFT
(
self
.
domain
,
self
.
codomain
)
else
:
raise
ValueError
(
'
ERROR: unknow
FFT module:'
+
module
)
raise
ValueError
(
'
Unsupported
FFT module:'
+
module
)
@
classmethod
def
get_codomain
(
cls
,
domain
,
zerocenter
=
None
):
...
...
nifty/operators/fft_operator/transformations/transformation.py
View file @
64a27773
...
...
@@ -18,8 +18,6 @@
import
abc
import
numpy
as
np
from
keepers
import
Loggable
...
...
@@ -40,7 +38,7 @@ class Transformation(Loggable, object):
self
.
codomain
=
codomain
@
classmethod
def
get_codomain
(
cls
,
domain
,
dtype
=
None
,
zerocenter
=
None
):
def
get_codomain
(
cls
,
domain
):
raise
NotImplementedError
@
classmethod
...
...
nifty/spaces/gl_space/gl_space.py
View file @
64a27773
...
...
@@ -186,6 +186,3 @@ class GLSpace(Space):
)
return
result
def
plot
(
self
):
pass
nifty/spaces/hp_space/hp_space.py
View file @
64a27773
...
...
@@ -145,6 +145,3 @@ class HPSpace(Space):
nside
=
hdf5_group
[
'nside'
][()],
)
return
result
def
plot
(
self
):
pass
nifty/spaces/lm_space/lm_space.py
View file @
64a27773
...
...
@@ -176,7 +176,3 @@ class LMSpace(Space):
lmax
=
hdf5_group
[
'lmax'
][()],
)
return
result
def
plot
(
self
):
pass
nifty/spaces/power_space/power_space.py
View file @
64a27773
...
...
@@ -24,7 +24,6 @@ from power_index_factory import PowerIndexFactory
from
nifty.spaces.space
import
Space
from
nifty.spaces.rg_space
import
RGSpace
from
nifty.nifty_utilities
import
cast_axis_to_tuple
class
PowerSpace
(
Space
):
...
...
@@ -200,14 +199,11 @@ class PowerSpace(Space):
new_ps
.
_pundex
=
hdf5_group
[
'pundex'
][:]
new_ps
.
_k_array
=
repository
.
get
(
'k_array'
,
hdf5_group
)
new_ps
.
_ignore_for_hash
+=
[
'_pindex'
,
'_kindex'
,
'_rho'
,
'_pundex'
,
'_k_array'
]
'_k_array'
]
return
new_ps
def
plot
(
self
):
pass
class
EmptyPowerSpace
(
PowerSpace
):
def
__init__
(
self
):
pass
nifty/spaces/rg_space/rg_space.py
View file @
64a27773
...
...
@@ -37,8 +37,6 @@ from d2o import distributed_data_object,\
from
nifty.spaces.space
import
Space
import
nifty.plotting
as
plt
class
RGSpace
(
Space
):
"""
...
...
@@ -315,8 +313,3 @@ class RGSpace(Space):
harmonic
=
hdf5_group
[
'harmonic'
][()],
)
return
result
def
plot
(
self
):
n_dimensions
=
len
(
self
.
_shape
)
# if n_dimensions == 1:
# fig = plt.figures.Figure(data=self.distances)
nifty/spaces/space/space.py
View file @
64a27773
...
...
@@ -141,8 +141,6 @@ from __future__ import division
import
abc
import
numpy
as
np
from
nifty.domain_object
import
DomainObject
...
...
Theo Steininger
@theos
mentioned in issue
#76 (closed)
·
May 01, 2017
mentioned in issue
#76 (closed)
mentioned in issue #76
Toggle commit list
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment