Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
24143e6a
Commit
24143e6a
authored
Apr 28, 2017
by
Martin Reinecke
Browse files
remove dtype from spaces
parent
7967f2d5
Changes
24
Hide whitespace changes
Inline
Side-by-side
demos/wiener_filter.py
View file @
24143e6a
from
nifty
import
*
#
import plotly.offline as pl
#
import plotly.graph_objs as go
import
plotly.offline
as
pl
import
plotly.graph_objs
as
go
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
...
...
@@ -10,10 +10,10 @@ rank = comm.rank
if
__name__
==
"__main__"
:
distribution_strategy
=
'
fftw
'
distribution_strategy
=
'
not
'
# Setting up the geometry
s_space
=
RGSpace
([
512
,
512
]
,
dtype
=
np
.
float64
)
s_space
=
RGSpace
([
512
,
512
])
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
...
...
@@ -50,8 +50,7 @@ if __name__ == "__main__":
d_data
=
d
.
val
.
get_full_data
().
real
m_data
=
m
.
val
.
get_full_data
().
real
ss_data
=
ss
.
val
.
get_full_data
().
real
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
if
rank
==
0
:
pl
.
plot
([
go
.
Heatmap
(
z
=
d_data
)],
filename
=
'data.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
m_data
)],
filename
=
'map.html'
)
pl
.
plot
([
go
.
Heatmap
(
z
=
ss_data
)],
filename
=
'map_orig.html'
)
demos/wiener_filter_hamiltonian.py
View file @
24143e6a
...
...
@@ -53,10 +53,10 @@ class WienerFilterEnergy(Energy):
if
__name__
==
"__main__"
:
distribution_strategy
=
'
fftw
'
distribution_strategy
=
'
not
'
# Set up spaces and fft transformation
s_space
=
RGSpace
([
512
,
512
]
,
dtype
=
np
.
float
)
s_space
=
RGSpace
([
512
,
512
])
fft
=
FFTOperator
(
s_space
)
h_space
=
fft
.
target
[
0
]
p_space
=
PowerSpace
(
h_space
,
distribution_strategy
=
distribution_strategy
)
...
...
nifty/domain_object.py
View file @
24143e6a
...
...
@@ -27,8 +27,7 @@ from keepers import Loggable,\
class
DomainObject
(
Versionable
,
Loggable
,
object
):
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
dtype
):
self
.
_dtype
=
np
.
dtype
(
dtype
)
def
__init__
(
self
):
self
.
_ignore_for_hash
=
[]
def
__hash__
(
self
):
...
...
@@ -50,10 +49,6 @@ class DomainObject(Versionable, Loggable, object):
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
@
property
def
dtype
(
self
):
return
self
.
_dtype
@
abc
.
abstractproperty
def
shape
(
self
):
raise
NotImplementedError
(
...
...
@@ -78,10 +73,9 @@ class DomainObject(Versionable, Loggable, object):
# ---Serialization---
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
repository
):
result
=
cls
(
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
result
=
cls
()
return
result
nifty/field.py
View file @
24143e6a
...
...
@@ -94,8 +94,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple
=
(
np
.
dtype
(
gc
[
'default_field_dtype'
]),)
else
:
dtype_tuple
=
(
np
.
dtype
(
dtype
),)
if
domain
is
not
None
:
dtype_tuple
+=
tuple
(
np
.
dtype
(
sp
.
dtype
)
for
sp
in
domain
)
dtype
=
reduce
(
lambda
x
,
y
:
np
.
result_type
(
x
,
y
),
dtype_tuple
)
...
...
@@ -345,7 +343,7 @@ class Field(Loggable, Versionable, object):
# create random samples: one or two, depending on whether the
# power spectrum is real or complex
if
issubclass
(
power_domain
.
dtype
.
type
,
np
.
complexfloating
):
if
issubclass
(
self
.
dtype
.
type
,
np
.
complexfloating
):
result_list
=
[
None
,
None
]
else
:
result_list
=
[
None
]
...
...
@@ -355,7 +353,7 @@ class Field(Loggable, Versionable, object):
mean
=
mean
,
std
=
std
,
domain
=
result_domain
,
dtype
=
harmonic_domain
.
dtype
,
dtype
=
self
.
dtype
,
distribution_strategy
=
self
.
distribution_strategy
)
for
x
in
result_list
]
...
...
@@ -403,7 +401,7 @@ class Field(Loggable, Versionable, object):
lambda
x
:
x
*
local_rescaler
.
real
,
inplace
=
True
)
if
issubclass
(
power_domain
.
dtype
.
type
,
np
.
complexfloating
):
if
issubclass
(
self
.
dtype
.
type
,
np
.
complexfloating
):
result_val_list
[
1
].
apply_scalar_function
(
lambda
x
:
x
*
local_rescaler
.
imag
,
inplace
=
True
)
...
...
@@ -412,7 +410,7 @@ class Field(Loggable, Versionable, object):
[
x
.
set_val
(
new_val
=
y
,
copy
=
False
)
for
x
,
y
in
zip
(
result_list
,
result_val_list
)]
if
issubclass
(
power_domain
.
dtype
.
type
,
np
.
complexfloating
):
if
issubclass
(
self
.
dtype
.
type
,
np
.
complexfloating
):
result
=
result_list
[
0
]
+
1j
*
result_list
[
1
]
else
:
result
=
result_list
[
0
]
...
...
nifty/operators/fft_operator/transformations/lmgltransformation.py
View file @
24143e6a
...
...
@@ -69,7 +69,7 @@ class LMGLTransformation(SlicingTransformation):
nlat
=
domain
.
lmax
+
1
nlon
=
domain
.
lmax
*
2
+
1
result
=
GLSpace
(
nlat
=
nlat
,
nlon
=
nlon
,
dtype
=
domain
.
dtype
)
result
=
GLSpace
(
nlat
=
nlat
,
nlon
=
nlon
)
return
result
@
classmethod
...
...
nifty/operators/fft_operator/transformations/lmhptransformation.py
View file @
24143e6a
...
...
@@ -65,7 +65,7 @@ class LMHPTransformation(SlicingTransformation):
raise
TypeError
(
"domain needs to be a LMSpace."
)
nside
=
max
((
domain
.
lmax
+
1
)
//
2
,
1
)
result
=
HPSpace
(
nside
=
nside
,
dtype
=
domain
.
dtype
)
result
=
HPSpace
(
nside
=
nside
)
return
result
@
classmethod
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
24143e6a
...
...
@@ -312,7 +312,7 @@ class FFTW(Transform):
try
:
# Create return object and insert results inplace
result_dtype
=
np
.
result_type
(
np
.
complex
,
self
.
codomain
.
dtype
)
result_dtype
=
np
.
complex
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
result_dtype
)
return_val
.
set_local_data
(
data
=
local_result
,
copy
=
False
)
...
...
@@ -341,7 +341,7 @@ class FFTW(Transform):
np
.
concatenate
([[
0
,
],
val
.
distributor
.
all_local_slices
[:,
2
]])
)
local_offset_Q
=
bool
(
local_offset_list
[
val
.
distributor
.
comm
.
rank
]
%
2
)
result_dtype
=
np
.
result_type
(
np
.
complex
,
self
.
codomain
.
dtype
)
result_dtype
=
np
.
complex
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
result_dtype
)
...
...
@@ -583,7 +583,7 @@ class NUMPYFFT(Transform):
not
all
(
axis
in
range
(
len
(
val
.
shape
))
for
axis
in
axes
):
raise
ValueError
(
"Provided axes does not match array shape"
)
result_dtype
=
np
.
result_type
(
np
.
complex
,
self
.
codomain
.
dtype
)
result_dtype
=
np
.
complex
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
result_dtype
)
...
...
nifty/operators/fft_operator/transformations/rgrgtransformation.py
View file @
24143e6a
...
...
@@ -44,7 +44,7 @@ class RGRGTransformation(Transformation):
raise
ValueError
(
'ERROR: unknow FFT module:'
+
module
)
@
classmethod
def
get_codomain
(
cls
,
domain
,
dtype
=
None
,
zerocenter
=
None
):
def
get_codomain
(
cls
,
domain
,
zerocenter
=
None
):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ either a shifted grid or a Fourier conjugate
...
...
@@ -82,8 +82,7 @@ class RGRGTransformation(Transformation):
new_space
=
RGSpace
(
domain
.
shape
,
zerocenter
=
zerocenter
,
distances
=
distances
,
harmonic
=
(
not
domain
.
harmonic
),
dtype
=
domain
.
dtype
)
harmonic
=
(
not
domain
.
harmonic
))
# better safe than sorry
cls
.
check_codomain
(
domain
,
new_space
)
...
...
nifty/operators/fft_operator/transformations/transformation.py
View file @
24143e6a
...
...
@@ -45,9 +45,7 @@ class Transformation(Loggable, object):
@
classmethod
def
check_codomain
(
cls
,
domain
,
codomain
):
if
np
.
dtype
(
domain
.
dtype
)
!=
np
.
dtype
(
codomain
.
dtype
):
cls
.
Logger
.
warn
(
"Unrecommended: domain and codomain don't have "
"the same dtype."
)
pass
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
raise
NotImplementedError
nifty/spaces/gl_space/gl_space.py
View file @
24143e6a
...
...
@@ -45,8 +45,6 @@ class GLSpace(Space):
Number of latitudinal bins, or rings.
nlon : int, *optional*
Number of longitudinal bins (default: ``2*nlat - 1``).
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
See Also
--------
...
...
@@ -62,15 +60,11 @@ class GLSpace(Space):
High-Resolution Discretization and Fast Analysis of Data
Distributed on the Sphere", *ApJ* 622..759G.
Attributes
----------
dtype : numpy.dtype
Data type of the field values.
"""
# ---Overwritten properties and methods---
def
__init__
(
self
,
nlat
,
nlon
=
None
,
dtype
=
None
):
def
__init__
(
self
,
nlat
,
nlon
=
None
):
"""
Sets the attributes for a gl_space class instance.
...
...
@@ -80,8 +74,6 @@ class GLSpace(Space):
Number of latitudinal bins, or rings.
nlon : int, *optional*
Number of longitudinal bins (default: ``2*nlat - 1``).
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
Returns
-------
...
...
@@ -97,7 +89,7 @@ class GLSpace(Space):
raise
ImportError
(
"The module pyHealpix is needed but not available."
)
super
(
GLSpace
,
self
).
__init__
(
dtype
)
super
(
GLSpace
,
self
).
__init__
()
self
.
_nlat
=
self
.
_parse_nlat
(
nlat
)
self
.
_nlon
=
self
.
_parse_nlon
(
nlon
)
...
...
@@ -122,8 +114,7 @@ class GLSpace(Space):
def
copy
(
self
):
return
self
.
__class__
(
nlat
=
self
.
nlat
,
nlon
=
self
.
nlon
,
dtype
=
self
.
dtype
)
nlon
=
self
.
nlon
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
nlon
=
self
.
nlon
...
...
@@ -184,7 +175,6 @@ class GLSpace(Space):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'nlat'
]
=
self
.
nlat
hdf5_group
[
'nlon'
]
=
self
.
nlon
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
...
...
@@ -193,10 +183,9 @@ class GLSpace(Space):
result
=
cls
(
nlat
=
hdf5_group
[
'nlat'
][()],
nlon
=
hdf5_group
[
'nlon'
][()],
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
def
plot
(
self
):
pass
\ No newline at end of file
pass
nifty/spaces/hp_space/hp_space.py
View file @
24143e6a
...
...
@@ -54,15 +54,11 @@ class HPSpace(Space):
harmonic transforms revisited";
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
Attributes
----------
dtype : numpy.dtype
Data type of the field values, which is always numpy.float64.
"""
# ---Overwritten properties and methods---
def
__init__
(
self
,
nside
,
dtype
=
None
):
def
__init__
(
self
,
nside
):
"""
Sets the attributes for a hp_space class instance.
...
...
@@ -83,7 +79,7 @@ class HPSpace(Space):
"""
super
(
HPSpace
,
self
).
__init__
(
dtype
)
super
(
HPSpace
,
self
).
__init__
()
self
.
_nside
=
self
.
_parse_nside
(
nside
)
...
...
@@ -106,8 +102,7 @@ class HPSpace(Space):
return
4
*
np
.
pi
def
copy
(
self
):
return
self
.
__class__
(
nside
=
self
.
nside
,
dtype
=
self
.
dtype
)
return
self
.
__class__
(
nside
=
self
.
nside
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
weight
=
((
4
*
np
.
pi
)
/
(
12
*
self
.
nside
**
2
))
**
power
...
...
@@ -142,16 +137,14 @@ class HPSpace(Space):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'nside'
]
=
self
.
nside
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
repository
):
result
=
cls
(
nside
=
hdf5_group
[
'nside'
][()],
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
def
plot
(
self
):
pass
\ No newline at end of file
pass
nifty/spaces/lm_space/lm_space.py
View file @
24143e6a
...
...
@@ -42,8 +42,6 @@ class LMSpace(Space):
lmax : int
Maximum :math:`\ell`-value up to which the spherical harmonics
coefficients are to be used.
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.complex128).
See Also
--------
...
...
@@ -66,14 +64,9 @@ class LMSpace(Space):
.. [#] M. Reinecke and D. Sverre Seljebotn, 2013, "Libsharp - spherical
harmonic transforms revisited";
`arXiv:1303.4945 <http://www.arxiv.org/abs/1303.4945>`_
Attributes
----------
dtype : numpy.dtype
Data type of the field values.
"""
def
__init__
(
self
,
lmax
,
dtype
=
None
):
def
__init__
(
self
,
lmax
):
"""
Sets the attributes for an lm_space class instance.
...
...
@@ -82,8 +75,6 @@ class LMSpace(Space):
lmax : int
Maximum :math:`\ell`-value up to which the spherical harmonics
coefficients are to be used.
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.complex128).
Returns
-------
...
...
@@ -91,7 +82,7 @@ class LMSpace(Space):
"""
super
(
LMSpace
,
self
).
__init__
(
dtype
)
super
(
LMSpace
,
self
).
__init__
()
self
.
_lmax
=
self
.
_parse_lmax
(
lmax
)
def
hermitian_decomposition
(
self
,
x
,
axes
=
None
,
...
...
@@ -125,11 +116,10 @@ class LMSpace(Space):
@
property
def
total_volume
(
self
):
# the individual pixels have a fixed volume of 1.
return
np
.
float
(
self
.
dim
)
return
np
.
float
64
(
self
.
dim
)
def
copy
(
self
):
return
self
.
__class__
(
lmax
=
self
.
lmax
,
dtype
=
self
.
dtype
)
return
self
.
__class__
(
lmax
=
self
.
lmax
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
if
inplace
:
...
...
@@ -143,7 +133,7 @@ class LMSpace(Space):
dists
=
dists
.
apply_scalar_function
(
lambda
x
:
self
.
_distance_array_helper
(
x
,
self
.
lmax
),
dtype
=
np
.
float
)
dtype
=
np
.
float
64
)
return
dists
...
...
@@ -178,17 +168,15 @@ class LMSpace(Space):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'lmax'
]
=
self
.
lmax
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
repository
):
result
=
cls
(
lmax
=
hdf5_group
[
'lmax'
][()],
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
def
plot
(
self
):
pass
\ No newline at end of file
pass
nifty/spaces/power_space/power_space.py
View file @
24143e6a
...
...
@@ -33,10 +33,9 @@ class PowerSpace(Space):
def
__init__
(
self
,
harmonic_domain
=
RGSpace
((
1
,)),
distribution_strategy
=
'not'
,
log
=
False
,
nbin
=
None
,
binbounds
=
None
,
dtype
=
None
):
log
=
False
,
nbin
=
None
,
binbounds
=
None
):
super
(
PowerSpace
,
self
).
__init__
(
dtype
)
super
(
PowerSpace
,
self
).
__init__
()
self
.
_ignore_for_hash
+=
[
'_pindex'
,
'_kindex'
,
'_rho'
,
'_pundex'
,
'_k_array'
]
...
...
@@ -97,8 +96,7 @@ class PowerSpace(Space):
distribution_strategy
=
distribution_strategy
,
log
=
self
.
log
,
nbin
=
self
.
nbin
,
binbounds
=
self
.
binbounds
,
dtype
=
self
.
dtype
)
binbounds
=
self
.
binbounds
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
reshaper
=
[
1
,
]
*
len
(
x
.
shape
)
...
...
@@ -171,7 +169,6 @@ class PowerSpace(Space):
hdf5_group
[
'kindex'
]
=
self
.
kindex
hdf5_group
[
'rho'
]
=
self
.
rho
hdf5_group
[
'pundex'
]
=
self
.
pundex
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
hdf5_group
[
'log'
]
=
self
.
log
# Store nbin as string, since it can be None
hdf5_group
.
attrs
[
'nbin'
]
=
str
(
self
.
nbin
)
...
...
@@ -190,7 +187,7 @@ class PowerSpace(Space):
# reset class
new_ps
.
__class__
=
cls
# call instructor so that classes are properly setup
super
(
PowerSpace
,
new_ps
).
__init__
(
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
super
(
PowerSpace
,
new_ps
).
__init__
()
# set all values
new_ps
.
_harmonic_domain
=
repository
.
get
(
'harmonic_domain'
,
hdf5_group
)
new_ps
.
_log
=
hdf5_group
[
'log'
][()]
...
...
nifty/spaces/rg_space/rg_space.py
View file @
24143e6a
...
...
@@ -52,9 +52,6 @@ class RGSpace(Space):
Attributes
----------
dtype : numpy.dtype
Data type of the field values for a field defined on this space,
either ``numpy.float64`` or ``numpy.complex128``.
harmonic : bool
Whether or not the grid represents a Fourier basis.
zerocenter : {bool, numpy.ndarray}, *optional*
...
...
@@ -67,7 +64,7 @@ class RGSpace(Space):
# ---Overwritten properties and methods---
def
__init__
(
self
,
shape
=
(
1
,),
zerocenter
=
False
,
distances
=
None
,
harmonic
=
False
,
dtype
=
None
):
harmonic
=
False
):
"""
Sets the attributes for an rg_space class instance.
...
...
@@ -92,13 +89,7 @@ class RGSpace(Space):
"""
self
.
_harmonic
=
bool
(
harmonic
)
if
dtype
is
None
:
if
self
.
harmonic
:
dtype
=
np
.
dtype
(
'complex'
)
else
:
dtype
=
np
.
dtype
(
'float'
)
super
(
RGSpace
,
self
).
__init__
(
dtype
)
super
(
RGSpace
,
self
).
__init__
()
self
.
_shape
=
self
.
_parse_shape
(
shape
)
self
.
_distances
=
self
.
_parse_distances
(
distances
)
...
...
@@ -198,8 +189,7 @@ class RGSpace(Space):
return
self
.
__class__
(
shape
=
self
.
shape
,
zerocenter
=
self
.
zerocenter
,
distances
=
self
.
distances
,
harmonic
=
self
.
harmonic
,
dtype
=
self
.
dtype
)
harmonic
=
self
.
harmonic
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
weight
=
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
distances
)
**
power
...
...
@@ -293,11 +283,11 @@ class RGSpace(Space):
def
_parse_distances
(
self
,
distances
):
if
distances
is
None
:
if
self
.
harmonic
:
temp
=
np
.
ones_like
(
self
.
shape
,
dtype
=
np
.
float
)
temp
=
np
.
ones_like
(
self
.
shape
,
dtype
=
np
.
float
64
)
else
:
temp
=
1
/
np
.
array
(
self
.
shape
,
dtype
=
np
.
float
)
temp
=
1
/
np
.
array
(
self
.
shape
,
dtype
=
np
.
float
64
)
else
:
temp
=
np
.
empty
(
len
(
self
.
shape
),
dtype
=
np
.
float
)
temp
=
np
.
empty
(
len
(
self
.
shape
),
dtype
=
np
.
float
64
)
temp
[:]
=
distances
return
tuple
(
temp
)
...
...
@@ -313,7 +303,6 @@ class RGSpace(Space):
hdf5_group
[
'zerocenter'
]
=
self
.
zerocenter
hdf5_group
[
'distances'
]
=
self
.
distances
hdf5_group
[
'harmonic'
]
=
self
.
harmonic
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
...
...
@@ -324,7 +313,6 @@ class RGSpace(Space):
zerocenter
=
hdf5_group
[
'zerocenter'
][:],
distances
=
hdf5_group
[
'distances'
][:],
harmonic
=
hdf5_group
[
'harmonic'
][()],
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
...
...
nifty/spaces/space/space.py
View file @
24143e6a
...
...
@@ -147,24 +147,18 @@ from nifty.domain_object import DomainObject
class
Space
(
DomainObject
):
def
__init__
(
self
,
dtype
=
np
.
dtype
(
'float'
)
):
def
__init__
(
self
):
"""
Parameters
----------
dtype : numpy.dtype, *optional*
Data type of the field values (default: numpy.float64).
None.
Returns
-------
None.
"""
# parse dtype
casted_dtype
=
np
.
result_type
(
dtype
,
np
.
float64
)
if
casted_dtype
!=
dtype
:
self
.
Logger
.
warning
(
"Input dtype reset to: %s"
%
str
(
casted_dtype
))
super
(
Space
,
self
).
__init__
(
dtype
=
casted_dtype
)
super
(
Space
,
self
).
__init__
()
self
.
_ignore_for_hash
+=
[
'_global_id'
]
@
abc
.
abstractproperty
...
...
@@ -178,7 +172,7 @@ class Space(DomainObject):
@
abc
.
abstractmethod
def
copy
(
self
):
return
self
.
__class__
(
dtype
=
self
.
dtype
)
return
self
.
__class__
()
def
get_distance_array
(
self
,
distribution_strategy
):
raise
NotImplementedError
(
...
...
@@ -195,5 +189,4 @@ class Space(DomainObject):
def
__repr__
(
self
):
string
=
""
string
+=
str
(
type
(
self
))
+
"
\n
"
string
+=
"dtype: "
+
str
(
self
.
dtype
)
+
"
\n
"
return
string
nifty/sugar.py
View file @
24143e6a
...
...
@@ -31,11 +31,10 @@ def create_power_operator(domain, power_spectrum, dtype=None,
domain
=
fft
.
target
[
0
]
power_domain
=
PowerSpace
(
domain
,
dtype
=
dtype
,
distribution_strategy
=
distribution_strategy
)
fp
=
Field
(
power_domain
,
val
=
power_spectrum
,
val
=
power_spectrum
,
dtype
=
dtype
,
distribution_strategy
=
distribution_strategy
)
f
=
fp
.
power_synthesize
(
mean
=
1
,
std
=
0
,
real_signal
=
False
)
...
...
test/test_field.py
View file @
24143e6a
...
...
@@ -35,14 +35,13 @@ from test.common import expand
np
.
random
.
seed
(
123
)
SPACES
=
[
RGSpace
((
4
,)
,
dtype
=
np
.
float
),
RGSpace
((
5
),
dtype
=
np
.
complex
)]
SPACES
=
[
RGSpace
((
4
,)
),
RGSpace
((
5
)
)]
SPACE_COMBINATIONS
=
[(),
SPACES
[
0
],
SPACES
[
1
],
SPACES
]
class
Test_Interface
(
unittest
.
TestCase
):
@
expand
(
product
(
SPACE_COMBINATIONS
,
[[
'dtype'
,
np
.
dtype
],
[
'distribution_strategy'
,
str
],
[[
'distribution_strategy'
,
str
],
[
'domain'
,
tuple
],
[
'domain_axes'
,
tuple
],
[
'val'
,
distributed_data_object
],
...
...
test/test_misc.py
View file @
24143e6a
...
...
@@ -43,6 +43,11 @@ def _harmonic_type(itp):
elif
otp
==
np
.
float32
:
otp
=
np
.
complex64
return
otp
def
_get_rtol
(
tp
):
if
(
tp
==
np
.
float64
)
or
(
tp
==
np
.
complex128
):
return
1e-10
else
:
return
1e-5
class
Misc_Tests
(
unittest
.
TestCase
):
...
...
@@ -70,12 +75,13 @@ class Misc_Tests(unittest.TestCase):
for
zc2
in
[
False
,
True