Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
b5e38b94
Commit
b5e38b94
authored
May 29, 2016
by
csongor
Browse files
WIP: fix _complement_cast
parent
0ba20735
Pipeline
#3934
skipped
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
lm/nifty_lm.py
View file @
b5e38b94
...
...
@@ -272,17 +272,23 @@ class lm_space(point_space):
mol
[
self
.
paradict
[
'lmax'
]
+
1
:]
=
2
# redundant: (l,m) and (l,-m)
return
mol
def
_cast_to_d2o
(
self
,
x
,
dtype
=
None
,
**
kwargs
):
casted_x
=
super
(
lm_space
,
self
).
_cast_to_d2o
(
x
=
x
,
dtype
=
dtype
,
**
kwargs
)
lmax
=
self
.
paradict
[
'lmax'
]
complexity_mask
=
casted_x
[:
lmax
+
1
].
iscomplex
()
if
complexity_mask
.
any
():
about
.
warnings
.
cprint
(
"WARNING: Taking the absolute values for "
+
"all complex entries where lmax==0"
)
casted_x
[:
lmax
+
1
]
=
abs
(
casted_x
[:
lmax
+
1
])
return
casted_x
def
_complement_cast
(
self
,
x
,
axis
=
None
,
**
kwargs
):
if
axis
is
None
:
lmax
=
self
.
paradict
[
'lmax'
]
complexity_mask
=
x
[:
lmax
+
1
].
iscomplex
()
if
complexity_mask
.
any
():
about
.
warnings
.
cprint
(
"WARNING: Taking the absolute values for "
+
"all complex entries where lmax==0"
)
x
[:
lmax
+
1
]
=
abs
(
x
[:
lmax
+
1
])
else
:
# TODO hermitianize only on specific axis
lmax
=
self
.
paradict
[
'lmax'
]
complexity_mask
=
x
[:
lmax
+
1
].
iscomplex
()
if
complexity_mask
.
any
():
about
.
warnings
.
cprint
(
"WARNING: Taking the absolute values for "
+
"all complex entries where lmax==0"
)
x
[:
lmax
+
1
]
=
abs
(
x
[:
lmax
+
1
])
return
x
# TODO: Extend to binning/log
def
enforce_power
(
self
,
spec
,
size
=
None
,
kindex
=
None
):
...
...
nifty_core.py
View file @
b5e38b94
...
...
@@ -319,31 +319,8 @@ class space(object):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'dof'."
))
def
cast
(
self
,
x
,
verbose
=
False
):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'cast'."
))
def
_complement_cast
(
self
,
x
,
axis
=
None
):
return
x
# TODO: Move enforce power into power_indices class
def
enforce_power
(
self
,
spec
,
**
kwargs
):
...
...
@@ -1006,93 +983,6 @@ class point_space(space):
mol
=
self
.
cast
(
1
,
dtype
=
np
.
dtype
(
'float'
))
return
self
.
calc_weight
(
mol
,
power
=
1
)
def
cast
(
self
,
x
=
None
,
dtype
=
None
,
**
kwargs
):
return
self
.
_cast_to_d2o
(
x
=
x
,
dtype
=
dtype
,
**
kwargs
)
def
_cast_to_d2o
(
self
,
x
,
dtype
=
None
,
**
kwargs
):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if
dtype
is
not
None
:
dtype
=
np
.
dtype
(
dtype
)
if
dtype
is
None
:
dtype
=
self
.
dtype
# Case 1: x is a distributed_data_object
if
isinstance
(
x
,
distributed_data_object
):
to_copy
=
False
# Check the shape
if
np
.
any
(
np
.
array
(
x
.
shape
)
!=
np
.
array
(
self
.
get_shape
())):
# Check if at least the number of degrees of freedom is equal
if
x
.
get_dim
()
==
self
.
get_dim
():
try
:
temp
=
x
.
copy_empty
(
global_shape
=
self
.
get_shape
())
temp
.
set_local_data
(
x
.
get_local_data
(),
copy
=
False
)
except
:
# If the number of dof is equal or 1, use np.reshape...
about
.
warnings
.
cflush
(
"WARNING: Trying to reshape the data. This "
+
"operation is expensive as it consolidates the "
+
"full data!
\n
"
)
temp
=
x
.
get_full_data
()
temp
=
np
.
reshape
(
temp
,
self
.
get_shape
())
# ... and cast again
return
self
.
_cast_to_d2o
(
temp
,
dtype
=
dtype
,
**
kwargs
)
else
:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: Data has incompatible shape!"
))
# Check the dtype
if
x
.
dtype
!=
dtype
:
if
x
.
dtype
>
dtype
:
about
.
warnings
.
cflush
(
"WARNING: Datatypes are of conflicting precision "
+
"(own: "
+
str
(
dtype
)
+
" <> foreign: "
+
str
(
x
.
dtype
)
+
") and will be casted! Potential "
+
"loss of precision!
\n
"
)
to_copy
=
True
if
to_copy
:
temp
=
x
.
copy_empty
(
dtype
=
dtype
)
temp
.
set_data
(
to_key
=
(
slice
(
None
),),
data
=
x
,
from_key
=
(
slice
(
None
),))
temp
.
hermitian
=
x
.
hermitian
x
=
temp
return
x
# Case 2: x is something else
# Use general d2o casting
else
:
x
=
distributed_data_object
(
x
,
global_shape
=
self
.
get_shape
(),
dtype
=
dtype
)
# Cast the d2o
return
self
.
cast
(
x
,
dtype
=
dtype
)
def
enforce_power
(
self
,
spec
,
**
kwargs
):
"""
Raises an error since the power spectrum is ill-defined for point
...
...
nifty_field.py
View file @
b5e38b94
...
...
@@ -102,9 +102,9 @@ class field(object):
"""
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
codomain
=
None
,
comm
=
gc
[
'default_comm'
],
copy
=
False
,
dtype
=
np
.
dtype
(
'float64'
),
datamodel
=
'not'
,
**
kwargs
):
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
codomain
=
None
,
comm
=
gc
[
'default_comm'
],
copy
=
False
,
dtype
=
np
.
dtype
(
'float64'
),
datamodel
=
'fftw'
,
**
kwargs
):
"""
Sets the attributes for a field class instance.
...
...
@@ -149,15 +149,14 @@ class field(object):
**
kwargs
)
def
_init_from_field
(
self
,
f
,
domain
,
codomain
,
comm
,
copy
,
dtype
,
datamodel
,
**
kwargs
):
datamodel
,
**
kwargs
):
# check domain
if
domain
is
None
:
domain
=
f
.
domain
# check codomain
if
codomain
is
None
:
if
self
.
check_codomain
(
domain
,
f
.
codomain
):
if
self
.
_
check_codomain
(
domain
,
f
.
codomain
):
codomain
=
f
.
codomain
else
:
codomain
=
self
.
get_codomain
(
domain
)
...
...
@@ -181,10 +180,18 @@ class field(object):
def
_init_from_array
(
self
,
val
,
domain
,
codomain
,
comm
,
copy
,
dtype
,
datamodel
,
**
kwargs
):
if
dtype
is
None
:
dtype
=
np
.
dtype
(
'float64'
)
dtype
=
self
.
_get_dtype_from_domain
(
domain
)
self
.
dtype
=
dtype
self
.
comm
=
self
.
_parse_comm
(
comm
)
# if val is a distributed data object, we take it's datamodel,
# since we don't want to redistribute large amounts of data, if not
# necessary
if
isinstance
(
val
,
distributed_data_object
):
if
datamodel
!=
val
.
distribution_strategy
:
about
.
warnings
.
cprint
(
"WARNING: datamodel set to val's "
"datamodel."
)
datamodel
=
val
.
distribution_strategy
if
datamodel
not
in
DISTRIBUTION_STRATEGIES
[
'global'
]:
about
.
warnings
.
cprint
(
"WARNING: datamodel set to default."
)
self
.
datamodel
=
\
...
...
@@ -192,12 +199,13 @@ class field(object):
else
:
self
.
datamodel
=
datamodel
# check domain
self
.
domain
=
self
.
check_valid_domain
(
domain
=
domain
)
self
.
domain
=
self
.
_check_valid_domain
(
domain
=
domain
)
self
.
_axis_list
=
self
.
_get_axis_list_from_domain
(
domain
=
domain
)
# check codomain
if
codomain
is
None
:
codomain
=
self
.
get_codomain
(
domain
)
elif
not
self
.
check_codomain
(
domain
=
domain
,
codomain
=
codomain
):
codomain
=
self
.
get_codomain
(
domain
=
domain
)
elif
not
self
.
_
check_codomain
(
domain
=
domain
,
codomain
=
codomain
):
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: The given codomain is not compatible to the domain."
))
self
.
codomain
=
codomain
...
...
@@ -210,6 +218,19 @@ class field(object):
codomain
=
z
,
**
kwargs
),
self
.
codomain
)
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
_get_dtype_from_domain
(
self
,
domain
=
None
):
if
domain
is
None
:
domain
=
self
.
domain
dtype_tuple
=
tuple
(
space
.
dtype
for
space
in
domain
)
dtype
=
np
.
result_type
(
dtype_tuple
)
return
dtype
def
_get_axis_list_from_domain
(
self
,
domain
=
None
):
if
domain
is
None
:
domain
=
self
.
domain
axis_list
=
[
space
.
get_shape
()
for
space
in
domain
]
return
axis_list
def
_parse_comm
(
self
,
comm
):
# check if comm is a string -> the name of comm is given
# -> Extract it from the mpi_module
...
...
@@ -229,7 +250,7 @@ class field(object):
"default-MPI-module's Intracomm Class."
))
return
result_comm
def
check_valid_domain
(
self
,
domain
):
def
_
check_valid_domain
(
self
,
domain
):
if
not
isinstance
(
domain
,
tuple
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: The given domain is not a list."
))
...
...
@@ -237,13 +258,13 @@ class field(object):
if
not
isinstance
(
d
,
space
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: Given domain is not a space."
))
elif
d
.
dtype
!=
self
.
dtype
:
elif
d
.
dtype
>
self
.
dtype
:
raise
AttributeError
(
about
.
_errors
.
cstring
(
"ERROR: The dtype of a space in the domain
m
is
smatches
"
"ERROR: The dtype of a space in the domain is
larger than
"
"the field's dtype."
))
return
domain
def
check_codomain
(
self
,
domain
,
codomain
):
def
_
check_codomain
(
self
,
domain
,
codomain
):
if
codomain
is
None
:
return
False
if
len
(
domain
)
==
len
(
codomain
):
...
...
@@ -453,9 +474,7 @@ class field(object):
temp
=
x
temp
=
np
.
reshape
(
temp
,
shape
)
# ... and cast again
return
self
.
_cast_to_d2o
(
temp
,
dtype
=
dtype
,
**
kwargs
)
return
self
.
_cast_to_d2o
(
temp
,
dtype
=
dtype
,
**
kwargs
)
else
:
raise
ValueError
(
about
.
_errors
.
cstring
(
...
...
@@ -497,7 +516,8 @@ class field(object):
return
self
.
cast
(
x
,
dtype
=
dtype
)
def
_complement_cast
(
self
,
x
):
# TODO implement complement cast for multiple spaces.
for
ind
,
space
in
enumerate
(
self
.
domain
):
space
.
_complement_cast
(
x
,
axis
=
self
.
_axis_list
[
ind
])
return
x
def
set_domain
(
self
,
new_domain
=
None
,
force
=
False
):
...
...
rg/nifty_rg.py
View file @
b5e38b94
...
...
@@ -251,30 +251,23 @@ class rg_space(point_space):
def
get_shape
(
self
):
return
tuple
(
self
.
paradict
[
'shape'
])
def
_cast_to_d2o
(
self
,
x
,
dtype
=
None
,
hermitianize
=
True
,
**
kwargs
):
casted_x
=
super
(
rg_space
,
self
).
_cast_to_d2o
(
x
=
x
,
dtype
=
dtype
,
**
kwargs
)
if
x
is
not
None
and
hermitianize
and
\
self
.
paradict
[
'complexity'
]
==
1
and
not
casted_x
.
hermitian
:
about
.
warnings
.
cflush
(
"WARNING: Data gets hermitianized. This operation is "
+
"extremely expensive
\n
"
)
casted_x
=
utilities
.
hermitianize
(
casted_x
)
return
casted_x
def
_cast_to_np
(
self
,
x
,
dtype
=
None
,
hermitianize
=
True
,
**
kwargs
):
casted_x
=
super
(
rg_space
,
self
).
_cast_to_np
(
x
=
x
,
dtype
=
dtype
,
**
kwargs
)
if
x
is
not
None
and
hermitianize
and
self
.
paradict
[
'complexity'
]
==
1
:
about
.
warnings
.
cflush
(
"WARNING: Data gets hermitianized. This operation is "
+
"extremely expensive
\n
"
)
casted_x
=
utilities
.
hermitianize
(
casted_x
)
return
casted_x
def
_complement_cast
(
self
,
x
,
axis
=
None
,
hermitianize
=
True
):
if
axis
is
None
:
if
x
is
not
None
and
hermitianize
and
self
.
paradict
[
'complexity'
]
\
==
1
and
not
x
.
hermitian
:
about
.
warnings
.
cflush
(
"WARNING: Data gets hermitianized. This operation is "
+
"extremely expensive
\n
"
)
x
=
utilities
.
hermitianize
(
x
)
else
:
# TODO hermitianize only on specific axis
if
x
is
not
None
and
hermitianize
and
self
.
paradict
[
'complexity'
]
\
==
1
and
not
x
.
hermitian
:
about
.
warnings
.
cflush
(
"WARNING: Data gets hermitianized. This operation is "
+
"extremely expensive
\n
"
)
x
=
utilities
.
hermitianize
(
x
)
return
x
def
enforce_power
(
self
,
spec
,
size
=
None
,
kindex
=
None
,
codomain
=
None
,
**
kwargs
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment