Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
25f052c3
Commit
25f052c3
authored
May 17, 2016
by
csongor
Browse files
WIP: Field support for multiple spaces.
parent
e735aeae
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty_core.py
View file @
25f052c3
...
...
@@ -911,9 +911,11 @@ class point_space(space):
'mean'
:
lambda
y
:
getattr
(
y
,
'mean'
)(
axis
=
axis
),
'std'
:
lambda
y
:
getattr
(
y
,
'std'
)(
axis
=
axis
),
'var'
:
lambda
y
:
getattr
(
y
,
'var'
)(
axis
=
axis
),
'argmin_nonflat'
:
lambda
y
:
getattr
(
y
,
'argmin_nonflat'
)(
axis
=
axis
),
'argmin_nonflat'
:
lambda
y
:
getattr
(
y
,
'argmin_nonflat'
)(
axis
=
axis
),
'argmin'
:
lambda
y
:
getattr
(
y
,
'argmin'
)(
axis
=
axis
),
'argmax_nonflat'
:
lambda
y
:
getattr
(
y
,
'argmax_nonflat'
)(
axis
=
axis
),
'argmax_nonflat'
:
lambda
y
:
getattr
(
y
,
'argmax_nonflat'
)(
axis
=
axis
),
'argmax'
:
lambda
y
:
getattr
(
y
,
'argmax'
)(
axis
=
axis
),
'conjugate'
:
lambda
y
:
getattr
(
y
,
'conjugate'
)(),
'sum'
:
lambda
y
:
getattr
(
y
,
'sum'
)(
axis
=
axis
),
...
...
@@ -1038,25 +1040,7 @@ class point_space(space):
return
self
.
calc_weight
(
mol
,
power
=
1
)
def
cast
(
self
,
x
=
None
,
dtype
=
None
,
**
kwargs
):
if
dtype
is
not
None
:
dtype
=
np
.
dtype
(
dtype
)
# If x is a field, extract the data and do a recursive call
if
isinstance
(
x
,
field
):
# Check if the domain matches
if
self
!=
x
.
domain
:
about
.
warnings
.
cflush
(
"WARNING: Getting data from foreign domain!"
)
# Extract the data, whatever it is, and cast it again
return
self
.
cast
(
x
.
val
,
dtype
=
dtype
,
**
kwargs
)
else
:
return
self
.
_cast_to_d2o
(
x
=
x
,
dtype
=
dtype
,
**
kwargs
)
return
self
.
_cast_to_d2o
(
x
=
x
,
dtype
=
dtype
,
**
kwargs
)
def
_cast_to_d2o
(
self
,
x
,
dtype
=
None
,
**
kwargs
):
"""
...
...
@@ -1081,6 +1065,8 @@ class point_space(space):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if
dtype
is
not
None
:
dtype
=
np
.
dtype
(
dtype
)
if
dtype
is
None
:
dtype
=
self
.
dtype
...
...
@@ -1357,8 +1343,8 @@ class point_space(space):
processed_std
=
std
else
:
try
:
processed_std
=
sample
.
distributor
.
\
extract_local_data
(
std
)
processed_std
=
sample
.
distributor
.
\
extract_local_data
(
std
)
except
(
AttributeError
):
processed_std
=
std
...
...
@@ -1375,8 +1361,6 @@ class point_space(space):
vmax
=
arg
[
'vmax'
]))
return
sample
def
calc_weight
(
self
,
x
,
power
=
1
):
"""
Weights a given array of field values with the pixel volumes (not
...
...
@@ -1575,7 +1559,7 @@ class point_space(space):
ax0
=
fig
.
add_axes
([
0.12
,
0.12
,
0.82
,
0.76
])
xaxes
=
np
.
arange
(
self
.
para
[
0
],
dtype
=
np
.
dtype
(
'int'
))
if
(
norm
==
"log"
)
and
(
vmin
<=
0
):
if
(
norm
==
"log"
)
and
(
vmin
<=
0
):
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: nonpositive value(s)."
))
...
...
@@ -1741,8 +1725,9 @@ class field(object):
"""
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
codomain
=
None
,
ishape
=
None
,
copy
=
False
,
**
kwargs
):
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
codomain
=
None
,
copy
=
False
,
dtype
=
np
.
dtype
(
'float64'
),
datamodel
=
'not'
,
**
kwargs
):
"""
Sets the attributes for a field class instance.
...
...
@@ -1771,32 +1756,31 @@ class field(object):
self
.
_init_from_field
(
f
=
val
,
domain
=
domain
,
codomain
=
codomain
,
ishape
=
ishape
,
copy
=
copy
,
dtype
=
dtype
,
datamodel
=
datamodel
,
**
kwargs
)
else
:
self
.
_init_from_array
(
val
=
val
,
domain
=
domain
,
codomain
=
codomain
,
ishape
=
ishape
,
copy
=
copy
,
dtype
=
dtype
,
datamodel
=
datamodel
,
**
kwargs
)
def
_init_from_field
(
self
,
f
,
domain
,
codomain
,
ishape
,
copy
,
**
kwargs
):
def
_init_from_field
(
self
,
f
,
domain
,
codomain
,
copy
,
dtype
,
datamodel
,
**
kwargs
):
# check domain
if
domain
is
None
:
domain
=
f
.
domain
# check codomain
if
codomain
is
None
:
if
domain
.
check_codomain
(
f
.
codomain
):
if
self
.
check_codomain
(
domain
,
f
.
codomain
):
codomain
=
f
.
codomain
else
:
codomain
=
domain
.
get_codomain
()
# check for ishape
if
ishape
is
None
:
ishape
=
f
.
ishape
codomain
=
self
.
get_codomain
(
domain
)
# Check if the given field lives in a space which is compatible to the
# given domain
...
...
@@ -1808,51 +1792,78 @@ class field(object):
self
.
_init_from_array
(
domain
=
domain
,
val
=
f
.
val
,
codomain
=
codomain
,
ishape
=
ishape
,
copy
=
copy
,
dtype
=
dtype
,
datamodel
=
datamodel
,
**
kwargs
)
def
_init_from_array
(
self
,
val
,
domain
,
codomain
,
ishape
,
copy
,
**
kwargs
):
def
_init_from_array
(
self
,
val
,
domain
,
codomain
,
copy
,
dtype
,
datamodel
,
**
kwargs
):
if
dtype
is
None
:
dtype
=
np
.
dtype
(
'float64'
)
self
.
dtype
=
dtype
if
datamodel
not
in
DISTRIBUTION_STRATEGIES
[
'global'
]:
about
.
warnings
.
cprint
(
"WARNING: datamodel set to default."
)
self
.
datamodel
=
\
gc
[
'default_distribution_strategy'
]
else
:
self
.
datamodel
=
datamodel
# check domain
if
not
isinstance
(
domain
,
space
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: Given domain is not a space."
))
self
.
domain
=
domain
self
.
domain
=
self
.
check_valid_domain
(
domain
=
domain
)
# check codomain
if
codomain
is
None
:
codomain
=
domain
.
get_codomain
()
elif
not
self
.
domain
.
check_codomain
(
codomain
):
codomain
=
self
.
get_codomain
(
domain
)
elif
not
self
.
check_codomain
(
domain
=
domain
,
codomain
=
codomain
):
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: The given codomain is not compatible to the domain."
))
self
.
codomain
=
codomain
if
ishape
is
not
None
:
ishape
=
tuple
(
np
.
array
(
ishape
,
dtype
=
np
.
uint
).
flatten
())
elif
val
is
not
None
:
try
:
if
val
.
dtype
.
type
==
np
.
object_
:
ishape
=
val
.
shape
else
:
ishape
=
()
except
(
AttributeError
):
try
:
ishape
=
val
.
ishape
except
(
AttributeError
):
ishape
=
()
else
:
ishape
=
()
self
.
ishape
=
ishape
if
val
is
None
:
if
kwargs
==
{}:
val
=
self
.
_map
(
lambda
:
self
.
domain
.
cast
(
0.
))
val
=
self
.
_map
(
lambda
:
self
.
cast
(
(
0
,)
))
else
:
val
=
self
.
_map
(
lambda
:
self
.
domain
.
get_random_values
(
codomain
=
self
.
codomain
,
**
kwargs
))
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
check_valid_domain
(
self
,
domain
):
if
not
isinstance
(
domain
,
np
.
ndarray
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: The given domain is not a list."
))
for
d
in
domain
:
if
not
isinstance
(
d
,
space
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: Given domain is not a space."
))
elif
d
.
dtype
!=
self
.
dtype
:
raise
AttributeError
(
about
.
_errors
.
cstring
(
"ERROR: The dtype of a space in the domain missmatches "
"the field's dtype."
))
elif
d
.
datamodel
!=
self
.
datamodel
:
raise
AttributeError
(
about
.
_errors
.
cstring
(
"ERROR: The datamodel of a space in the domain missmatches "
"the field's datamodel."
))
return
domain
def
check_codomain
(
self
,
domain
,
codomain
):
if
codomain
is
None
:
return
False
if
domain
.
shape
==
codomain
.
shape
:
return
np
.
all
(
map
((
lambda
d
,
c
:
d
.
_check_codomain
(
c
)),
domain
,
codomain
))
else
:
return
False
def
get_codomain
(
self
,
domain
):
if
domain
.
shape
==
(
1
,):
return
np
.
array
(
domain
[
0
].
get_codomain
())
else
:
# TODO implement for multiple domain get_codomain need
# calc_transform
return
np
.
empty
((
0
,))
def
__len__
(
self
):
return
int
(
self
.
get_dim
(
split
=
True
)[
0
])
...
...
@@ -2010,29 +2021,127 @@ class field(object):
def
get_ishape
(
self
):
return
self
.
ishape
def
get_global_shape
(
self
):
global_shape
=
np
.
sum
([
space
.
get_shape
()
for
space
in
self
.
domain
])
if
isinstance
(
global_shape
,
tuple
):
return
global_shape
else
:
return
()
def
_map
(
self
,
function
,
*
args
):
return
utilities
.
field_map
(
self
.
i
shape
,
function
,
*
args
)
return
utilities
.
field_map
(
self
.
get_global_
shape
()
,
function
,
*
args
)
def
cast
(
self
,
x
=
None
,
ishape
=
None
):
if
ishape
is
None
:
ishape
=
self
.
ishape
casted_x
=
self
.
_cast_to_ishape
(
x
,
ishape
=
ishape
)
if
ishape
==
():
def
cast
(
self
,
x
=
None
,
dtype
=
None
):
if
dtype
is
not
None
:
dtype
=
np
.
dtype
(
dtype
)
if
dtype
is
None
:
dtype
=
self
.
dtype
casted_x
=
self
.
_cast_to_shape
(
x
)
if
self
.
get_global_shape
()
==
():
return
self
.
domain
.
cast
(
casted_x
)
else
:
return
self
.
_map
(
lambda
z
:
self
.
domain
.
cast
(
z
),
casted_x
)
def
_cast_to_ishape
(
self
,
x
,
ishape
=
None
):
if
ishape
is
None
:
ishape
=
self
.
ishape
def
_cast_to_d2o
(
self
,
x
,
dtype
=
None
,
**
kwargs
):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if
dtype
is
None
:
dtype
=
self
.
dtype
# Case 1: x is a distributed_data_object
if
isinstance
(
x
,
distributed_data_object
):
to_copy
=
False
# Check the shape
if
np
.
any
(
np
.
array
(
x
.
shape
)
!=
np
.
array
(
self
.
get_shape
())):
# Check if at least the number of degrees of freedom is equal
if
x
.
get_dim
()
==
self
.
get_dim
():
try
:
temp
=
x
.
copy_empty
(
global_shape
=
self
.
get_shape
())
temp
.
set_local_data
(
x
.
get_local_data
(),
copy
=
False
)
except
:
# If the number of dof is equal or 1, use np.reshape...
about
.
warnings
.
cflush
(
"WARNING: Trying to reshape the data. This "
+
"operation is expensive as it consolidates the "
+
"full data!
\n
"
)
temp
=
x
.
get_full_data
()
temp
=
np
.
reshape
(
temp
,
self
.
get_shape
())
# ... and cast again
return
self
.
_cast_to_d2o
(
temp
,
dtype
=
dtype
,
**
kwargs
)
else
:
raise
ValueError
(
about
.
_errors
.
cstring
(
"ERROR: Data has incompatible shape!"
))
# Check the dtype
if
x
.
dtype
!=
dtype
:
if
x
.
dtype
>
dtype
:
about
.
warnings
.
cflush
(
"WARNING: Datatypes are of conflicting precision "
+
"(own: "
+
str
(
dtype
)
+
" <> foreign: "
+
str
(
x
.
dtype
)
+
") and will be casted! Potential "
+
"loss of precision!
\n
"
)
to_copy
=
True
# Check the distribution_strategy
if
x
.
distribution_strategy
!=
self
.
datamodel
:
to_copy
=
True
if
to_copy
:
temp
=
x
.
copy_empty
(
dtype
=
dtype
,
distribution_strategy
=
self
.
datamodel
)
temp
.
set_data
(
to_key
=
(
slice
(
None
),),
data
=
x
,
from_key
=
(
slice
(
None
),))
temp
.
hermitian
=
x
.
hermitian
x
=
temp
return
x
# Case 2: x is something else
# Use general d2o casting
else
:
x
=
distributed_data_object
(
x
,
global_shape
=
self
.
get_shape
(),
dtype
=
dtype
,
distribution_strategy
=
self
.
datamodel
)
# Cast the d2o
return
self
.
cast
(
x
,
dtype
=
dtype
)
def
_cast_to_shape
(
self
,
x
):
if
isinstance
(
x
,
field
):
x
=
x
.
get_val
()
if
ishape
==
():
global_shape
=
self
.
get_global_shape
()
if
global_shape
==
():
casted_x
=
self
.
_cast_to_scalar_helper
(
x
)
else
:
casted_x
=
self
.
_cast_to_tensor_helper
(
x
,
i
shape
)
casted_x
=
self
.
_cast_to_tensor_helper
(
x
,
shape
=
global_
shape
)
return
casted_x
def
_cast_to_scalar_helper
(
self
,
x
):
...
...
@@ -2068,26 +2177,26 @@ class field(object):
# In all other cases, cast x directly
return
x
def
_cast_to_tensor_helper
(
self
,
x
,
i
shape
=
None
):
if
i
shape
is
None
:
i
shape
=
self
.
i
shape
def
_cast_to_tensor_helper
(
self
,
x
,
shape
=
None
):
if
shape
is
None
:
shape
=
self
.
get_global_
shape
()
# Check if x is a container of proper length
# containing something which will then checked by the domain-space
x_shape
=
np
.
shape
(
x
)
self_shape
=
self
.
domain
.
get
_shape
()
self_shape
=
self
.
get_global
_shape
()
try
:
container_Q
=
(
x
.
dtype
.
type
==
np
.
object_
)
except
(
AttributeError
):
container_Q
=
False
if
container_Q
:
if
x_shape
==
i
shape
:
if
x_shape
==
shape
:
return
x
elif
x_shape
==
i
shape
[:
len
(
x_shape
)]:
elif
x_shape
==
shape
[:
len
(
x_shape
)]:
return
x
.
reshape
(
x_shape
+
(
1
,)
*
(
len
(
i
shape
)
-
len
(
x_shape
)))
(
1
,)
*
(
len
(
shape
)
-
len
(
x_shape
)))
# Slow track: x could be a pure ndarray
...
...
@@ -2095,11 +2204,11 @@ class field(object):
# 1: There are cases where np.shape will only find the container
# although it was no np.object array; e.g. for [a,1].
# 2: The overall shape is already the right one
if
x_shape
==
i
shape
or
x_shape
==
(
i
shape
+
self_shape
):
if
x_shape
==
shape
or
x_shape
==
(
shape
+
self_shape
):
# Iterate over the outermost dimension and cast the inner spaces
result
=
np
.
empty
(
i
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
i
shape
)):
ii
=
np
.
unravel_index
(
i
,
i
shape
)
result
=
np
.
empty
(
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
shape
)):
ii
=
np
.
unravel_index
(
i
,
shape
)
try
:
result
[
ii
]
=
x
[
ii
]
except
(
TypeError
):
...
...
@@ -2112,16 +2221,16 @@ class field(object):
# Check if the input has shape (1, self.domain.shape)
# Iterate over the outermost dimension and cast the inner spaces
elif
x_shape
==
((
1
,)
+
self_shape
):
result
=
np
.
empty
(
i
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
i
shape
)):
ii
=
np
.
unravel_index
(
i
,
i
shape
)
result
=
np
.
empty
(
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
shape
)):
ii
=
np
.
unravel_index
(
i
,
shape
)
result
[
ii
]
=
x
[
0
]
# Case 4: fallback: try to cast x with self.domain
else
:
# Iterate over the outermost dimension and cast the inner spaces
result
=
np
.
empty
(
i
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
i
shape
)):
ii
=
np
.
unravel_index
(
i
,
i
shape
)
result
=
np
.
empty
(
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
shape
)):
ii
=
np
.
unravel_index
(
i
,
shape
)
result
[
ii
]
=
x
return
result
...
...
@@ -2903,4 +3012,4 @@ class field(object):
class
EmptyField
(
field
):
def
__init__
(
self
):
pass
\ No newline at end of file
pass
test/test_nifty_field.py
View file @
25f052c3
...
...
@@ -56,11 +56,11 @@ all_hp_datatypes = [np.dtype('float64')]
###############################################################################
DATAMODELS
=
{}
DATAMODELS
[
'point_space'
]
=
[
'np'
]
+
POINT_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'rg_space'
]
=
[
'np'
]
+
RG_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'lm_space'
]
=
[
'np'
]
+
LM_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'gl_space'
]
=
[
'np'
]
+
GL_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'hp_space'
]
=
[
'np'
]
+
HP_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'point_space'
]
=
POINT_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'rg_space'
]
=
RG_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'lm_space'
]
=
LM_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'gl_space'
]
=
GL_DISTRIBUTION_STRATEGIES
DATAMODELS
[
'hp_space'
]
=
HP_DISTRIBUTION_STRATEGIES
###############################################################################
...
...
@@ -110,10 +110,9 @@ class Test_field_init(unittest.TestCase):
@
parameterized
.
expand
(
space_list
)
def
test_successfull_init_and_attributes
(
self
,
s
):
s
=
s
[
0
]
f
=
field
(
s
)
assert
(
f
.
domain
is
s
)
assert
(
s
.
check_codomain
(
f
.
codomain
))
f
=
field
(
domain
=
np
.
array
([
s
]),
dtype
=
s
.
dtype
)
assert
(
f
.
domain
[
0
]
is
s
)
assert
(
s
.
check_codomain
(
f
.
codomain
[
0
]))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment