Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
37582f42
Commit
37582f42
authored
Feb 07, 2017
by
Theo Steininger
Browse files
Unified spaces and field_types into single domain object.
parent
e38eae92
Pipeline
#9997
passed with stage
in 33 minutes and 35 seconds
Changes
17
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/domain_object.py
0 → 100644
View file @
37582f42
# -*- coding: utf-8 -*-
import
abc
import
numpy
as
np
from
keepers
import
Loggable
,
\
Versionable
class
DomainObject
(
Versionable
,
Loggable
,
object
):
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
dtype
):
self
.
_dtype
=
np
.
dtype
(
dtype
)
self
.
_ignore_for_hash
=
[]
def
__hash__
(
self
):
# Extract the identifying parts from the vars(self) dict.
result_hash
=
0
for
key
in
sorted
(
vars
(
self
).
keys
()):
item
=
vars
(
self
)[
key
]
if
key
in
self
.
_ignore_for_hash
or
key
==
'_ignore_for_hash'
:
continue
result_hash
^=
item
.
__hash__
()
^
int
(
hash
(
key
)
/
117
)
return
result_hash
def
__eq__
(
self
,
x
):
if
isinstance
(
x
,
type
(
self
)):
return
hash
(
self
)
==
hash
(
x
)
else
:
return
False
def
__ne__
(
self
,
x
):
return
not
self
.
__eq__
(
x
)
@
property
def
dtype
(
self
):
return
self
.
_dtype
@
abc
.
abstractproperty
def
shape
(
self
):
raise
NotImplementedError
(
"There is no generic shape for DomainObject."
)
@
abc
.
abstractproperty
def
dim
(
self
):
raise
NotImplementedError
(
"There is no generic dim for DomainObject."
)
def
pre_cast
(
self
,
x
,
axes
=
None
):
return
x
def
post_cast
(
self
,
x
,
axes
=
None
):
return
x
# ---Serialization---
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
repository
):
result
=
cls
(
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
]))
return
result
nifty/field.py
View file @
37582f42
...
@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\
...
@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\
from
nifty.config
import
nifty_configuration
as
gc
from
nifty.config
import
nifty_configuration
as
gc
from
nifty.
field_types
import
FieldType
from
nifty.
domain_object
import
DomainObject
from
nifty.spaces.space
import
Space
from
nifty.spaces.power_space
import
PowerSpace
from
nifty.spaces.power_space
import
PowerSpace
import
nifty.nifty_utilities
as
utilities
import
nifty.nifty_utilities
as
utilities
...
@@ -21,25 +20,15 @@ from nifty.random import Random
...
@@ -21,25 +20,15 @@ from nifty.random import Random
class
Field
(
Loggable
,
Versionable
,
object
):
class
Field
(
Loggable
,
Versionable
,
object
):
# ---Initialization methods---
# ---Initialization methods---
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
field_type
=
None
,
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
dtype
=
None
,
distribution_strategy
=
None
,
copy
=
False
):
distribution_strategy
=
None
,
copy
=
False
):
self
.
domain
=
self
.
_parse_domain
(
domain
=
domain
,
val
=
val
)
self
.
domain
=
self
.
_parse_domain
(
domain
=
domain
,
val
=
val
)
self
.
domain_axes
=
self
.
_get_axes_tuple
(
self
.
domain
)
self
.
domain_axes
=
self
.
_get_axes_tuple
(
self
.
domain
)
self
.
field_type
=
self
.
_parse_field_type
(
field_type
,
val
=
val
)
try
:
start
=
len
(
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain_axes
))
except
TypeError
:
start
=
0
self
.
field_type_axes
=
self
.
_get_axes_tuple
(
self
.
field_type
,
start
=
start
)
self
.
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
self
.
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
,
val
=
val
,
domain
=
self
.
domain
,
domain
=
self
.
domain
)
field_type
=
self
.
field_type
)
self
.
distribution_strategy
=
self
.
_parse_distribution_strategy
(
self
.
distribution_strategy
=
self
.
_parse_distribution_strategy
(
distribution_strategy
=
distribution_strategy
,
distribution_strategy
=
distribution_strategy
,
...
@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object):
...
@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object):
domain
=
val
.
domain
domain
=
val
.
domain
else
:
else
:
domain
=
()
domain
=
()
elif
isinstance
(
domain
,
Space
):
elif
isinstance
(
domain
,
DomainObject
):
domain
=
(
domain
,)
domain
=
(
domain
,)
elif
not
isinstance
(
domain
,
tuple
):
elif
not
isinstance
(
domain
,
tuple
):
domain
=
tuple
(
domain
)
domain
=
tuple
(
domain
)
for
d
in
domain
:
for
d
in
domain
:
if
not
isinstance
(
d
,
Space
):
if
not
isinstance
(
d
,
DomainObject
):
raise
TypeError
(
raise
TypeError
(
"Given domain contains something that is not a "
"Given domain contains something that is not a "
"
nifty.spa
ce."
)
"
DomainObject instan
ce."
)
return
domain
return
domain
def
_parse_field_type
(
self
,
field_type
,
val
=
None
):
if
field_type
is
None
:
if
isinstance
(
val
,
Field
):
field_type
=
val
.
field_type
else
:
field_type
=
()
elif
isinstance
(
field_type
,
FieldType
):
field_type
=
(
field_type
,)
elif
not
isinstance
(
field_type
,
tuple
):
field_type
=
tuple
(
field_type
)
for
ft
in
field_type
:
if
not
isinstance
(
ft
,
FieldType
):
raise
TypeError
(
"Given object is not a nifty.FieldType."
)
return
field_type
def
_get_axes_tuple
(
self
,
things_with_shape
,
start
=
0
):
def
_get_axes_tuple
(
self
,
things_with_shape
,
start
=
0
):
i
=
start
i
=
start
axes_list
=
[]
axes_list
=
[]
...
@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object):
...
@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object):
axes_list
+=
[
tuple
(
l
)]
axes_list
+=
[
tuple
(
l
)]
return
tuple
(
axes_list
)
return
tuple
(
axes_list
)
def
_infer_dtype
(
self
,
dtype
,
val
,
domain
,
field_type
):
def
_infer_dtype
(
self
,
dtype
,
val
,
domain
):
if
dtype
is
None
:
if
dtype
is
None
:
if
isinstance
(
val
,
Field
)
or
\
if
isinstance
(
val
,
Field
)
or
\
isinstance
(
val
,
distributed_data_object
):
isinstance
(
val
,
distributed_data_object
):
...
@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object):
...
@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple
=
(
np
.
dtype
(
dtype
),)
dtype_tuple
=
(
np
.
dtype
(
dtype
),)
if
domain
is
not
None
:
if
domain
is
not
None
:
dtype_tuple
+=
tuple
(
np
.
dtype
(
sp
.
dtype
)
for
sp
in
domain
)
dtype_tuple
+=
tuple
(
np
.
dtype
(
sp
.
dtype
)
for
sp
in
domain
)
if
field_type
is
not
None
:
dtype_tuple
+=
tuple
(
np
.
dtype
(
ft
.
dtype
)
for
ft
in
field_type
)
dtype
=
reduce
(
lambda
x
,
y
:
np
.
result_type
(
x
,
y
),
dtype_tuple
)
dtype
=
reduce
(
lambda
x
,
y
:
np
.
result_type
(
x
,
y
),
dtype_tuple
)
...
@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object):
...
@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object):
# ---Factory methods---
# ---Factory methods---
@
classmethod
@
classmethod
def
from_random
(
cls
,
random_type
,
domain
=
None
,
dtype
=
None
,
field_type
=
None
,
def
from_random
(
cls
,
random_type
,
domain
=
None
,
dtype
=
None
,
distribution_strategy
=
None
,
**
kwargs
):
distribution_strategy
=
None
,
**
kwargs
):
# create a initially empty field
# create a initially empty field
f
=
cls
(
domain
=
domain
,
dtype
=
dtype
,
field_type
=
field_type
,
f
=
cls
(
domain
=
domain
,
dtype
=
dtype
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
# now use the processed input in terms of f in order to parse the
# now use the processed input in terms of f in order to parse the
...
@@ -363,7 +334,6 @@ class Field(Loggable, Versionable, object):
...
@@ -363,7 +334,6 @@ class Field(Loggable, Versionable, object):
std
=
std
,
std
=
std
,
domain
=
result_domain
,
domain
=
result_domain
,
dtype
=
harmonic_domain
.
dtype
,
dtype
=
harmonic_domain
.
dtype
,
field_type
=
self
.
field_type
,
distribution_strategy
=
self
.
distribution_strategy
)
distribution_strategy
=
self
.
distribution_strategy
)
for
x
in
result_list
]
for
x
in
result_list
]
...
@@ -451,9 +421,7 @@ class Field(Loggable, Versionable, object):
...
@@ -451,9 +421,7 @@ class Field(Loggable, Versionable, object):
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
shape_tuple
=
()
shape_tuple
=
tuple
(
sp
.
shape
for
sp
in
self
.
domain
)
shape_tuple
+=
tuple
(
sp
.
shape
for
sp
in
self
.
domain
)
shape_tuple
+=
tuple
(
ft
.
shape
for
ft
in
self
.
field_type
)
try
:
try
:
global_shape
=
reduce
(
lambda
x
,
y
:
x
+
y
,
shape_tuple
)
global_shape
=
reduce
(
lambda
x
,
y
:
x
+
y
,
shape_tuple
)
except
TypeError
:
except
TypeError
:
...
@@ -463,9 +431,7 @@ class Field(Loggable, Versionable, object):
...
@@ -463,9 +431,7 @@ class Field(Loggable, Versionable, object):
@
property
@
property
def
dim
(
self
):
def
dim
(
self
):
dim_tuple
=
()
dim_tuple
=
tuple
(
sp
.
dim
for
sp
in
self
.
domain
)
dim_tuple
+=
tuple
(
sp
.
dim
for
sp
in
self
.
domain
)
dim_tuple
+=
tuple
(
ft
.
dim
for
ft
in
self
.
field_type
)
try
:
try
:
return
reduce
(
lambda
x
,
y
:
x
*
y
,
dim_tuple
)
return
reduce
(
lambda
x
,
y
:
x
*
y
,
dim_tuple
)
except
TypeError
:
except
TypeError
:
...
@@ -500,20 +466,12 @@ class Field(Loggable, Versionable, object):
...
@@ -500,20 +466,12 @@ class Field(Loggable, Versionable, object):
casted_x
=
sp
.
pre_cast
(
casted_x
,
casted_x
=
sp
.
pre_cast
(
casted_x
,
axes
=
self
.
domain_axes
[
ind
])
axes
=
self
.
domain_axes
[
ind
])
for
ind
,
ft
in
enumerate
(
self
.
field_type
):
casted_x
=
ft
.
pre_cast
(
casted_x
,
axes
=
self
.
field_type_axes
[
ind
])
casted_x
=
self
.
_actual_cast
(
casted_x
,
dtype
=
dtype
)
casted_x
=
self
.
_actual_cast
(
casted_x
,
dtype
=
dtype
)
for
ind
,
sp
in
enumerate
(
self
.
domain
):
for
ind
,
sp
in
enumerate
(
self
.
domain
):
casted_x
=
sp
.
post_cast
(
casted_x
,
casted_x
=
sp
.
post_cast
(
casted_x
,
axes
=
self
.
domain_axes
[
ind
])
axes
=
self
.
domain_axes
[
ind
])
for
ind
,
ft
in
enumerate
(
self
.
field_type
):
casted_x
=
ft
.
post_cast
(
casted_x
,
axes
=
self
.
field_type_axes
[
ind
])
return
casted_x
return
casted_x
def
_actual_cast
(
self
,
x
,
dtype
=
None
):
def
_actual_cast
(
self
,
x
,
dtype
=
None
):
...
@@ -530,19 +488,16 @@ class Field(Loggable, Versionable, object):
...
@@ -530,19 +488,16 @@ class Field(Loggable, Versionable, object):
return_x
.
set_full_data
(
x
,
copy
=
False
)
return_x
.
set_full_data
(
x
,
copy
=
False
)
return
return_x
return
return_x
def
copy
(
self
,
domain
=
None
,
dtype
=
None
,
field_type
=
None
,
def
copy
(
self
,
domain
=
None
,
dtype
=
None
,
distribution_strategy
=
None
):
distribution_strategy
=
None
):
copied_val
=
self
.
get_val
(
copy
=
True
)
copied_val
=
self
.
get_val
(
copy
=
True
)
new_field
=
self
.
copy_empty
(
new_field
=
self
.
copy_empty
(
domain
=
domain
,
domain
=
domain
,
dtype
=
dtype
,
dtype
=
dtype
,
field_type
=
field_type
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
new_field
.
set_val
(
new_val
=
copied_val
,
copy
=
False
)
new_field
.
set_val
(
new_val
=
copied_val
,
copy
=
False
)
return
new_field
return
new_field
def
copy_empty
(
self
,
domain
=
None
,
dtype
=
None
,
field_type
=
None
,
def
copy_empty
(
self
,
domain
=
None
,
dtype
=
None
,
distribution_strategy
=
None
):
distribution_strategy
=
None
):
if
domain
is
None
:
if
domain
is
None
:
domain
=
self
.
domain
domain
=
self
.
domain
else
:
else
:
...
@@ -553,11 +508,6 @@ class Field(Loggable, Versionable, object):
...
@@ -553,11 +508,6 @@ class Field(Loggable, Versionable, object):
else
:
else
:
dtype
=
np
.
dtype
(
dtype
)
dtype
=
np
.
dtype
(
dtype
)
if
field_type
is
None
:
field_type
=
self
.
field_type
else
:
field_type
=
self
.
_parse_field_type
(
field_type
)
if
distribution_strategy
is
None
:
if
distribution_strategy
is
None
:
distribution_strategy
=
self
.
distribution_strategy
distribution_strategy
=
self
.
distribution_strategy
...
@@ -567,10 +517,6 @@ class Field(Loggable, Versionable, object):
...
@@ -567,10 +517,6 @@ class Field(Loggable, Versionable, object):
if
self
.
domain
[
i
]
is
not
domain
[
i
]:
if
self
.
domain
[
i
]
is
not
domain
[
i
]:
fast_copyable
=
False
fast_copyable
=
False
break
break
for
i
in
xrange
(
len
(
self
.
field_type
)):
if
self
.
field_type
[
i
]
is
not
field_type
[
i
]:
fast_copyable
=
False
break
except
IndexError
:
except
IndexError
:
fast_copyable
=
False
fast_copyable
=
False
...
@@ -580,7 +526,6 @@ class Field(Loggable, Versionable, object):
...
@@ -580,7 +526,6 @@ class Field(Loggable, Versionable, object):
else
:
else
:
new_field
=
Field
(
domain
=
domain
,
new_field
=
Field
(
domain
=
domain
,
dtype
=
dtype
,
dtype
=
dtype
,
field_type
=
field_type
,
distribution_strategy
=
distribution_strategy
)
distribution_strategy
=
distribution_strategy
)
return
new_field
return
new_field
...
@@ -626,8 +571,6 @@ class Field(Loggable, Versionable, object):
...
@@ -626,8 +571,6 @@ class Field(Loggable, Versionable, object):
assert
len
(
x
.
domain
)
==
len
(
self
.
domain
)
assert
len
(
x
.
domain
)
==
len
(
self
.
domain
)
for
index
in
xrange
(
len
(
self
.
domain
)):
for
index
in
xrange
(
len
(
self
.
domain
)):
assert
x
.
domain
[
index
]
==
self
.
domain
[
index
]
assert
x
.
domain
[
index
]
==
self
.
domain
[
index
]
for
index
in
xrange
(
len
(
self
.
field_type
)):
assert
x
.
field_type
[
index
]
==
self
.
field_type
[
index
]
except
AssertionError
:
except
AssertionError
:
raise
ValueError
(
raise
ValueError
(
"domains are incompatible."
)
"domains are incompatible."
)
...
@@ -707,22 +650,15 @@ class Field(Loggable, Versionable, object):
...
@@ -707,22 +650,15 @@ class Field(Loggable, Versionable, object):
return_field
.
set_val
(
new_val
,
copy
=
False
)
return_field
.
set_val
(
new_val
,
copy
=
False
)
return
return_field
return
return_field
def
_contraction_helper
(
self
,
op
,
spaces
,
types
):
def
_contraction_helper
(
self
,
op
,
spaces
):
# build a list of all axes
# build a list of all axes
if
spaces
is
None
:
if
spaces
is
None
:
spaces
=
xrange
(
len
(
self
.
domain
))
spaces
=
xrange
(
len
(
self
.
domain
))
else
:
else
:
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
types
is
None
:
axes_list
=
tuple
(
self
.
domain_axes
[
sp_index
]
for
sp_index
in
spaces
)
types
=
xrange
(
len
(
self
.
field_type
))
else
:
types
=
utilities
.
cast_axis_to_tuple
(
types
,
len
(
self
.
field_type
))
axes_list
=
()
axes_list
+=
tuple
(
self
.
domain_axes
[
sp_index
]
for
sp_index
in
spaces
)
axes_list
+=
tuple
(
self
.
field_type_axes
[
ft_index
]
for
ft_index
in
types
)
try
:
try
:
axes_list
=
reduce
(
lambda
x
,
y
:
x
+
y
,
axes_list
)
axes_list
=
reduce
(
lambda
x
,
y
:
x
+
y
,
axes_list
)
except
TypeError
:
except
TypeError
:
...
@@ -739,47 +675,44 @@ class Field(Loggable, Versionable, object):
...
@@ -739,47 +675,44 @@ class Field(Loggable, Versionable, object):
return_domain
=
tuple
(
self
.
domain
[
i
]
return_domain
=
tuple
(
self
.
domain
[
i
]
for
i
in
xrange
(
len
(
self
.
domain
))
for
i
in
xrange
(
len
(
self
.
domain
))
if
i
not
in
spaces
)
if
i
not
in
spaces
)
return_field_type
=
tuple
(
self
.
field_type
[
i
]
for
i
in
xrange
(
len
(
self
.
field_type
))
if
i
not
in
types
)
return_field
=
Field
(
domain
=
return_domain
,
return_field
=
Field
(
domain
=
return_domain
,
val
=
data
,
val
=
data
,
field_type
=
return_field_type
,
copy
=
False
)
copy
=
False
)
return
return_field
return
return_field
def
sum
(
self
,
spaces
=
None
,
types
=
None
):
def
sum
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'sum'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'sum'
,
spaces
)
def
prod
(
self
,
spaces
=
None
,
types
=
None
):
def
prod
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'prod'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'prod'
,
spaces
)
def
all
(
self
,
spaces
=
None
,
types
=
None
):
def
all
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'all'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'all'
,
spaces
)
def
any
(
self
,
spaces
=
None
,
types
=
None
):
def
any
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'any'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'any'
,
spaces
)
def
min
(
self
,
spaces
=
None
,
types
=
None
):
def
min
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'min'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'min'
,
spaces
)
def
nanmin
(
self
,
spaces
=
None
,
types
=
None
):
def
nanmin
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'nanmin'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'nanmin'
,
spaces
)
def
max
(
self
,
spaces
=
None
,
types
=
None
):
def
max
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'max'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'max'
,
spaces
)
def
nanmax
(
self
,
spaces
=
None
,
types
=
None
):
def
nanmax
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'nanmax'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'nanmax'
,
spaces
)
def
mean
(
self
,
spaces
=
None
,
types
=
None
):
def
mean
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'mean'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'mean'
,
spaces
)
def
var
(
self
,
spaces
=
None
,
types
=
None
):
def
var
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'var'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'var'
,
spaces
)
def
std
(
self
,
spaces
=
None
,
types
=
None
):
def
std
(
self
,
spaces
=
None
):
return
self
.
_contraction_helper
(
'std'
,
spaces
,
types
)
return
self
.
_contraction_helper
(
'std'
,
spaces
)
# ---General binary methods---
# ---General binary methods---
...
@@ -790,9 +723,6 @@ class Field(Loggable, Versionable, object):
...
@@ -790,9 +723,6 @@ class Field(Loggable, Versionable, object):
assert
len
(
other
.
domain
)
==
len
(
self
.
domain
)
assert
len
(
other
.
domain
)
==
len
(
self
.
domain
)
for
index
in
xrange
(
len
(
self
.
domain
)):
for
index
in
xrange
(
len
(
self
.
domain
)):
assert
other
.
domain
[
index
]
==
self
.
domain
[
index
]
assert
other
.
domain
[
index
]
==
self
.
domain
[
index
]
assert
len
(
other
.
field_type
)
==
len
(
self
.
field_type
)
for
index
in
xrange
(
len
(
self
.
field_type
)):
assert
other
.
field_type
[
index
]
==
self
.
field_type
[
index
]
except
AssertionError
:
except
AssertionError
:
raise
ValueError
(
raise
ValueError
(
"domains are incompatible."
)
"domains are incompatible."
)
...
@@ -895,19 +825,14 @@ class Field(Loggable, Versionable, object):
...
@@ -895,19 +825,14 @@ class Field(Loggable, Versionable, object):
def
_to_hdf5
(
self
,
hdf5_group
):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
hdf5_group
.
attrs
[
'distribution_strategy'
]
=
self
.
distribution_strategy
hdf5_group
.
attrs
[
'distribution_strategy'
]
=
self
.
distribution_strategy
hdf5_group
.
attrs
[
'field_type_axes'
]
=
str
(
self
.
field_type_axes
)
hdf5_group
.
attrs
[
'domain_axes'
]
=
str
(
self
.
domain_axes
)
hdf5_group
.
attrs
[
'domain_axes'
]
=
str
(
self
.
domain_axes
)
hdf5_group
[
'num_domain'
]
=
len
(
self
.
domain
)
hdf5_group
[
'num_domain'
]
=
len
(
self
.
domain
)
hdf5_group
[
'num_ft'
]
=
len
(
self
.
field_type
)
ret_dict
=
{
'val'
:
self
.
val
}
ret_dict
=
{
'val'
:
self
.
val
}
for
i
in
range
(
len
(
self
.
domain
)):
for
i
in
range
(
len
(
self
.
domain
)):
ret_dict
[
's_'
+
str
(
i
)]
=
self
.
domain
[
i
]
ret_dict
[
's_'
+
str
(
i
)]
=
self
.
domain
[
i
]
for
i
in
range
(
len
(
self
.
field_type
)):
ret_dict
[
'ft_'
+
str
(
i
)]
=
self
.
field_type
[
i
]
return
ret_dict
return
ret_dict
@
classmethod
@
classmethod
...
@@ -922,14 +847,7 @@ class Field(Loggable, Versionable, object):
...
@@ -922,14 +847,7 @@ class Field(Loggable, Versionable, object):
temp_domain
.
append
(
repository
.
get
(
's_'
+
str
(
i
),
hdf5_group
))
temp_domain
.
append
(
repository
.
get
(
's_'
+
str
(
i
),
hdf5_group
))
new_field
.
domain
=
tuple
(
temp_domain
)
new_field
.
domain
=
tuple
(
temp_domain
)
temp_ft
=
[]
for
i
in
range
(
hdf5_group
[
'num_ft'
][()]):
temp_domain
.
append
(
repository
.
get
(
'ft_'
+
str
(
i
),
hdf5_group
))
new_field
.
field_type
=
tuple
(
temp_ft
)
exec
(
'new_field.domain_axes = '
+
hdf5_group
.
attrs
[
'domain_axes'
])
exec
(
'new_field.domain_axes = '
+
hdf5_group
.
attrs
[
'domain_axes'
])
exec
(
'new_field.field_type_axes = '
+
hdf5_group
.
attrs
[
'field_type_axes'
])
new_field
.
_val
=
repository
.
get
(
'val'
,
hdf5_group
)
new_field
.
_val
=
repository
.
get
(
'val'
,
hdf5_group
)
new_field
.
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
new_field
.
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
new_field
.
distribution_strategy
=
\
new_field
.
distribution_strategy
=
\
...
...
nifty/field_types/field_array.py
View file @
37582f42
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
pickle
from
field_type
import
FieldType
from
field_type
import
FieldType
class
FieldArray
(
FieldType
):
class
FieldArray
(
FieldType
):
def
__init__
(
self
,
dtype
,
shape
):
try
:
new_shape
=
tuple
([
int
(
i
)
for
i
in
shape
])
except
TypeError
:
new_shape
=
(
int
(
shape
),
)
self
.
_shape
=
new_shape
super
(
FieldArray
,
self
).
__init__
(
dtype
=
dtype
)
@
property
def
shape
(
self
):
return
self
.
_shape
@
property
@
property
def
dim
(
self
):
def
dim
(
self
):
return
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
shape
)
return
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
shape
)
# ---Serialization---
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'shape'
]
=
self
.
shape
hdf5_group
[
'dtype'
]
=
pickle
.
dumps
(
self
.
dtype
)
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
loopback_get
):
result
=
cls
(
hdf5_group
[
'shape'
][:],
pickle
.
loads
(
hdf5_group
[
'dtype'
][()])
)