Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
d8b9b918
Commit
d8b9b918
authored
May 19, 2016
by
csongor
Browse files
WIP: Field support for multiple spaces.
parent
25f052c3
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty_core.py
View file @
d8b9b918
...
...
@@ -1994,7 +1994,11 @@ class field(object):
return
gotten
def
get_shape
(
self
):
return
self
.
domain
.
get_shape
()
global_shape
=
np
.
sum
([
space
.
get_shape
()
for
space
in
self
.
domain
])
if
isinstance
(
global_shape
,
tuple
):
return
global_shape
else
:
return
()
def
get_dim
(
self
,
split
=
False
):
"""
...
...
@@ -2013,37 +2017,23 @@ class field(object):
Dimension of space.
"""
return
self
.
domain
.
get_dim
(
split
=
split
)
return
np
.
prod
(
np
.
sum
([
space
.
get_shape
()
for
space
in
self
.
domain
])
)
def
get_dof
(
self
,
split
=
False
):
return
self
.
domain
.
get_dof
(
split
=
split
)
def
get_ishape
(
self
):
return
self
.
ishape
def
get_global_shape
(
self
):
global_shape
=
np
.
sum
([
space
.
get_shape
()
for
space
in
self
.
domain
])
if
isinstance
(
global_shape
,
tuple
):
return
global_shape
else
:
return
()
return
np
.
sum
([
len
(
space
.
get_shape
())
for
space
in
self
.
domain
])
def
_map
(
self
,
function
,
*
args
):
return
utilities
.
field_map
(
self
.
get_
global_
shape
(),
function
,
*
args
)
return
utilities
.
field_map
(
self
.
get_shape
(),
function
,
*
args
)
def
cast
(
self
,
x
=
None
,
dtype
=
None
):
if
dtype
is
not
None
:
dtype
=
np
.
dtype
(
dtype
)
if
dtype
is
None
:
dtype
=
self
.
dtype
casted_x
=
self
.
_cast_to_shape
(
x
)
if
self
.
get_global_shape
()
==
():
return
self
.
domain
.
cast
(
casted_x
)
else
:
return
self
.
_map
(
lambda
z
:
self
.
domain
.
cast
(
z
),
casted_x
)
casted_x
=
self
.
_cast_to_d2o
(
x
,
dtype
=
dtype
)
return
self
.
_complement_cast
(
casted_x
)
def
_cast_to_d2o
(
self
,
x
,
dtype
=
None
,
**
kwargs
):
def
_cast_to_d2o
(
self
,
x
,
dtype
=
None
,
shape
=
None
,
**
kwargs
):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
...
...
@@ -2066,29 +2056,34 @@ class field(object):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if
isinstance
(
x
,
field
):
x
=
x
.
get_val
()
if
dtype
is
None
:
dtype
=
self
.
dtype
if
shape
is
None
:
shape
=
self
.
get_shape
()
# Case 1: x is a distributed_data_object
if
isinstance
(
x
,
distributed_data_object
):
to_copy
=
False
# Check the shape
if
np
.
any
(
np
.
array
(
x
.
shape
)
!=
np
.
array
(
self
.
get_
shape
()
)):
if
np
.
any
(
np
.
array
(
x
.
shape
)
!=
np
.
array
(
shape
)):
# Check if at least the number of degrees of freedom is equal
if
x
.
get_dim
()
==
self
.
get_dim
():
try
:
temp
=
x
.
copy_empty
(
global_shape
=
self
.
get_
shape
()
)
temp
.
set_local_data
(
x
.
get_local_data
()
,
copy
=
False
)
temp
=
x
.
copy_empty
(
global_shape
=
shape
)
temp
.
set_local_data
(
x
,
copy
=
False
)
except
:
# If the number of dof is equal or 1, use np.reshape...
about
.
warnings
.
cflush
(
"WARNING: Trying to reshape the data. This "
+
"operation is expensive as it consolidates the "
+
"full data!
\n
"
)
temp
=
x
.
get_full_data
()
temp
=
np
.
reshape
(
temp
,
self
.
get_
shape
()
)
temp
=
x
temp
=
np
.
reshape
(
temp
,
shape
)
# ... and cast again
return
self
.
_cast_to_d2o
(
temp
,
dtype
=
dtype
,
...
...
@@ -2133,108 +2128,10 @@ class field(object):
# Cast the d2o
return
self
.
cast
(
x
,
dtype
=
dtype
)
def
_cast_to_shape
(
self
,
x
):
if
isinstance
(
x
,
field
):
x
=
x
.
get_val
()
global_shape
=
self
.
get_global_shape
()
if
global_shape
==
():
casted_x
=
self
.
_cast_to_scalar_helper
(
x
)
else
:
casted_x
=
self
.
_cast_to_tensor_helper
(
x
,
shape
=
global_shape
)
return
casted_x
def
_cast_to_scalar_helper
(
self
,
x
):
# if x is already a scalar or does fit directly, return it
self_shape
=
self
.
domain
.
get_shape
()
x_shape
=
np
.
shape
(
x
)
if
np
.
isscalar
(
x
)
or
x_shape
==
self_shape
:
return
x
# check if the given object is a 'container'
try
:
container_Q
=
(
x
.
dtype
.
type
==
np
.
object_
)
except
(
AttributeError
):
container_Q
=
False
if
container_Q
:
# extract the first element. This works on 0-d ndarrays, too.
result
=
x
[(
0
,)
*
len
(
x_shape
)]
return
result
# if x is no container-type, it could be that the needed shape
# for self.domain is encapsulated in x
if
x_shape
[
len
(
x_shape
)
-
len
(
self_shape
):]
==
self_shape
:
if
x_shape
[:
len
(
x_shape
)
-
len
(
self_shape
)]
!=
(
1
,):
about
.
warnings
.
cprint
(
"WARNING: discarding all internal dimensions "
+
"except for the first one."
)
result
=
x
for
i
in
xrange
(
len
(
x_shape
)
-
len
(
self_shape
)):
result
=
result
[
0
]
return
result
# In all other cases, cast x directly
def
_complement_cast
(
self
,
x
):
#TODO implement complement cast for multiple spaces.
return
x
def
_cast_to_tensor_helper
(
self
,
x
,
shape
=
None
):
if
shape
is
None
:
shape
=
self
.
get_global_shape
()
# Check if x is a container of proper length
# containing something which will then checked by the domain-space
x_shape
=
np
.
shape
(
x
)
self_shape
=
self
.
get_global_shape
()
try
:
container_Q
=
(
x
.
dtype
.
type
==
np
.
object_
)
except
(
AttributeError
):
container_Q
=
False
if
container_Q
:
if
x_shape
==
shape
:
return
x
elif
x_shape
==
shape
[:
len
(
x_shape
)]:
return
x
.
reshape
(
x_shape
+
(
1
,)
*
(
len
(
shape
)
-
len
(
x_shape
)))
# Slow track: x could be a pure ndarray
# Case 1 and 2:
# 1: There are cases where np.shape will only find the container
# although it was no np.object array; e.g. for [a,1].
# 2: The overall shape is already the right one
if
x_shape
==
shape
or
x_shape
==
(
shape
+
self_shape
):
# Iterate over the outermost dimension and cast the inner spaces
result
=
np
.
empty
(
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
shape
)):
ii
=
np
.
unravel_index
(
i
,
shape
)
try
:
result
[
ii
]
=
x
[
ii
]
except
(
TypeError
):
extracted
=
x
for
j
in
xrange
(
len
(
ii
)):
extracted
=
extracted
[
ii
[
j
]]
result
[
ii
]
=
extracted
# Case 3: The overall shape does not match directly.
# Check if the input has shape (1, self.domain.shape)
# Iterate over the outermost dimension and cast the inner spaces
elif
x_shape
==
((
1
,)
+
self_shape
):
result
=
np
.
empty
(
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
shape
)):
ii
=
np
.
unravel_index
(
i
,
shape
)
result
[
ii
]
=
x
[
0
]
# Case 4: fallback: try to cast x with self.domain
else
:
# Iterate over the outermost dimension and cast the inner spaces
result
=
np
.
empty
(
shape
,
dtype
=
np
.
object
)
for
i
in
xrange
(
np
.
prod
(
shape
)):
ii
=
np
.
unravel_index
(
i
,
shape
)
result
[
ii
]
=
x
return
result
def
set_domain
(
self
,
new_domain
=
None
,
force
=
False
):
"""
Resets the codomain of the field.
...
...
test/test_nifty_field.py
View file @
d8b9b918
...
...
@@ -110,7 +110,7 @@ class Test_field_init(unittest.TestCase):
@
parameterized
.
expand
(
space_list
)
def
test_successfull_init_and_attributes
(
self
,
s
):
f
=
field
(
domain
=
np
.
array
([
s
]),
dtype
=
s
.
dtype
)
f
=
field
(
domain
=
np
.
array
([
s
]),
dtype
=
s
.
dtype
,
datamodel
=
s
.
datamodel
)
assert
(
f
.
domain
[
0
]
is
s
)
assert
(
s
.
check_codomain
(
f
.
codomain
[
0
]))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment