Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
5db4d734
Commit
5db4d734
authored
May 29, 2016
by
csongor
Browse files
WIP: fix field.__init__ for multiple spaces
parent
7714b6cb
Pipeline
#3933
skipped
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty_field.py
View file @
5db4d734
...
...
@@ -2,12 +2,12 @@ from __future__ import division
import
numpy
as
np
import
pylab
as
pl
from
d2o
import
distributed_data_object
,
\
STRATEGIES
as
DISTRIBUTION_STRATEGIES
from
d2o
import
distributed_data_object
,
\
STRATEGIES
as
DISTRIBUTION_STRATEGIES
from
nifty.config
import
about
,
\
nifty_configuration
as
gc
,
\
dependency_injector
as
gdi
from
nifty.config
import
about
,
\
nifty_configuration
as
gc
,
\
dependency_injector
as
gdi
from
nifty.nifty_core
import
space
...
...
@@ -104,7 +104,7 @@ class field(object):
def
__init__
(
self
,
domain
=
None
,
val
=
None
,
codomain
=
None
,
comm
=
gc
[
'default_comm'
],
copy
=
False
,
dtype
=
np
.
dtype
(
'float64'
),
datamodel
=
'not'
,
**
kwargs
):
**
kwargs
):
"""
Sets the attributes for a field class instance.
...
...
@@ -256,12 +256,12 @@ class field(object):
if
len
(
domain
)
==
1
:
return
(
domain
[
0
].
get_codomain
(),)
else
:
# TODO implement for multiple domain get_codomain need
# calc_transform
return
np
.
empty
((
0
,))
codomain
=
tuple
(
space
.
get_codomain
()
for
space
in
domain
)
self
.
codomain
=
codomain
return
codomain
def
__len__
(
self
):
return
int
(
self
.
get_dim
(
split
=
True
)[
0
])
return
int
(
self
.
get_dim
()[
0
])
def
apply_scalar_function
(
self
,
function
,
inplace
=
False
):
if
inplace
:
...
...
@@ -312,11 +312,11 @@ class field(object):
datamodel
=
self
.
datamodel
if
(
domain
is
self
.
domain
and
codomain
is
self
.
codomain
and
dtype
==
self
.
dtype
and
comm
==
self
.
comm
and
datamodel
==
self
.
datamodel
and
kwargs
==
{}):
codomain
is
self
.
codomain
and
dtype
==
self
.
dtype
and
comm
==
self
.
comm
and
datamodel
==
self
.
datamodel
and
kwargs
==
{}):
new_field
=
self
.
_fast_copy_empty
()
else
:
new_field
=
field
(
domain
=
domain
,
codomain
=
codomain
,
dtype
=
dtype
,
...
...
@@ -336,8 +336,8 @@ class field(object):
if
new_val
is
not
None
:
if
copy
:
new_val
=
map
(
lambda
z
:
self
.
unary_operation
(
z
,
'copy'
),
new_val
)
lambda
z
:
self
.
unary_operation
(
z
,
'copy'
),
new_val
)
self
.
val
=
map
(
lambda
z
:
self
.
cast
(
z
),
new_val
)
return
self
.
val
...
...
@@ -348,7 +348,7 @@ class field(object):
def
__getitem__
(
self
,
key
):
if
np
.
isscalar
(
key
)
==
True
or
isinstance
(
key
,
slice
):
key
=
(
key
,
)
key
=
(
key
,)
if
self
.
ishape
==
():
return
self
.
domain
.
getitem
(
self
.
get_val
(),
key
)
else
:
...
...
@@ -362,7 +362,7 @@ class field(object):
if
is_data_container
:
gotten
=
map
(
lambda
z
:
self
.
domain
.
getitem
(
z
,
key
[
len
(
self
.
ishape
):]),
z
,
key
[
len
(
self
.
ishape
):]),
gotten
)
else
:
gotten
=
self
.
domain
.
getitem
(
gotten
,
...
...
@@ -371,7 +371,7 @@ class field(object):
def
__setitem__
(
self
,
key
,
value
):
if
np
.
isscalar
(
key
)
or
isinstance
(
key
,
slice
):
key
=
(
key
,
)
key
=
(
key
,)
if
self
.
ishape
==
():
return
self
.
domain
.
setitem
(
self
.
get_val
(),
value
,
key
)
else
:
...
...
@@ -385,7 +385,7 @@ class field(object):
if
is_data_container
:
gotten
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
setitem
(
z1
,
z2
,
key
[
len
(
self
.
ishape
):]),
z1
,
z2
,
key
[
len
(
self
.
ishape
):]),
gotten
,
value
)
else
:
gotten
=
self
.
domain
.
setitem
(
gotten
,
value
,
...
...
@@ -393,13 +393,13 @@ class field(object):
else
:
dummy
=
np
.
empty
(
self
.
ishape
)
gotten
=
self
.
val
.
__setitem__
(
key
,
self
.
cast
(
value
,
ishape
=
np
.
shape
(
dummy
[
key
])))
value
,
ishape
=
np
.
shape
(
dummy
[
key
])))
return
gotten
def
get_shape
(
self
):
if
len
(
self
.
domain
)
>
1
:
global_shape
=
reduce
(
lambda
x
,
y
:
x
.
get_shape
()
+
y
.
get_shape
(),
self
.
domain
)
global_shape
=
reduce
(
lambda
x
,
y
:
x
.
get_shape
()
+
y
.
get_shape
(),
self
.
domain
)
else
:
global_shape
=
self
.
domain
[
0
].
get_shape
()
...
...
@@ -408,7 +408,7 @@ class field(object):
else
:
return
()
def
get_dim
(
self
,
split
=
False
):
def
get_dim
(
self
):
"""
Computes the (array) dimension of the underlying space.
...
...
@@ -430,7 +430,7 @@ class field(object):
def
get_dof
(
self
,
split
=
False
):
dim
=
self
.
get_dim
()
if
np
.
issubdtype
(
self
.
dtype
,
np
.
complex
):
return
2
*
dim
return
2
*
dim
else
:
return
dim
...
...
@@ -541,7 +541,7 @@ class field(object):
return
self
.
cast
(
x
,
dtype
=
dtype
)
def
_complement_cast
(
self
,
x
):
#TODO implement complement cast for multiple spaces.
#
TODO implement complement cast for multiple spaces.
return
x
def
set_domain
(
self
,
new_domain
=
None
,
force
=
False
):
...
...
@@ -559,7 +559,7 @@ class field(object):
if
new_domain
is
None
:
new_domain
=
self
.
codomain
.
get_codomain
()
elif
not
force
:
assert
(
self
.
codomain
.
check_codomain
(
new_domain
))
assert
(
self
.
codomain
.
check_codomain
(
new_domain
))
self
.
domain
=
new_domain
return
self
.
domain
...
...
@@ -578,7 +578,7 @@ class field(object):
if
new_codomain
is
None
:
new_codomain
=
self
.
domain
.
get_codomain
()
elif
not
force
:
assert
(
self
.
domain
.
check_codomain
(
new_codomain
))
assert
(
self
.
domain
.
check_codomain
(
new_codomain
))
self
.
codomain
=
new_codomain
return
self
.
codomain
...
...
@@ -609,7 +609,7 @@ class field(object):
new_field
=
self
.
copy_empty
()
new_val
=
map
(
lambda
y
:
self
.
domain
.
calc_weight
(
y
,
power
=
power
),
self
.
get_val
())
self
.
get_val
())
new_field
.
set_val
(
new_val
=
new_val
)
return
new_field
...
...
@@ -630,9 +630,9 @@ class field(object):
"""
if
q
==
0.5
:
return
(
self
.
dot
(
x
=
self
))
**
(
1
/
2
)
return
(
self
.
dot
(
x
=
self
))
**
(
1
/
2
)
else
:
return
self
.
dot
(
x
=
self
**
(
q
-
1
))
**
(
1
/
q
)
return
self
.
dot
(
x
=
self
**
(
q
-
1
))
**
(
1
/
q
)
def
dot
(
self
,
x
=
None
,
axis
=
None
,
bare
=
False
):
"""
...
...
@@ -788,7 +788,7 @@ class field(object):
else
:
new_codomain
=
new_domain
.
get_codomain
()
else
:
assert
(
new_domain
.
check_codomain
(
new_codomain
))
assert
(
new_domain
.
check_codomain
(
new_codomain
))
new_val
=
map
(
lambda
z
:
self
.
domain
.
calc_transform
(
...
...
@@ -882,7 +882,7 @@ class field(object):
Returns the power spectrum.
"""
if
(
"codomain"
in
kwargs
):
if
(
"codomain"
in
kwargs
):
kwargs
.
__delitem__
(
"codomain"
)
about
.
warnings
.
cprint
(
"WARNING: codomain was removed from kwargs."
)
...
...
@@ -1013,12 +1013,12 @@ class field(object):
minmax
=
[
self
.
min
(),
self
.
max
()]
mean
=
self
.
mean
()
return
"nifty_core.field instance
\n
- domain = "
+
\
repr
(
self
.
domain
)
+
\
"
\n
- val = "
+
repr
(
self
.
get_val
())
+
\
"
\n
- min.,max. = "
+
str
(
minmax
)
+
\
"
\n
- mean = "
+
str
(
mean
)
+
\
"
\n
- codomain = "
+
repr
(
self
.
codomain
)
+
\
"
\n
- ishape = "
+
str
(
self
.
ishape
)
repr
(
self
.
domain
)
+
\
"
\n
- val = "
+
repr
(
self
.
get_val
())
+
\
"
\n
- min.,max. = "
+
str
(
minmax
)
+
\
"
\n
- mean = "
+
str
(
mean
)
+
\
"
\n
- codomain = "
+
repr
(
self
.
codomain
)
+
\
"
\n
- ishape = "
+
str
(
self
.
ishape
)
def
_unary_helper
(
self
,
x
,
op
,
**
kwargs
):
result
=
map
(
...
...
@@ -1233,13 +1233,10 @@ class field(object):
other_val
=
other
# bring other_val into the right shape
if
self
.
ishape
==
():
other_val
=
self
.
_cast_to_scalar_helper
(
other_val
)
else
:
other_val
=
self
.
_cast_to_tensor_helper
(
other_val
)
other_val
=
self
.
_cast_to_d2o
(
other_val
)
new_val
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
binary_operation
(
z1
,
z2
,
op
=
op
,
cast
=
0
),
lambda
z1
,
z2
:
self
.
binary_operation
(
z1
,
z2
,
op
=
op
,
cast
=
0
),
self
.
get_val
(),
other_val
)
...
...
@@ -1251,8 +1248,82 @@ class field(object):
working_field
.
set_val
(
new_val
=
new_val
)
return
working_field
def
unary_operation
(
self
,
x
,
op
=
'None'
,
axis
=
None
,
**
kwargs
):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
"""
translation
=
{
'pos'
:
lambda
y
:
getattr
(
y
,
'__pos__'
)(),
'neg'
:
lambda
y
:
getattr
(
y
,
'__neg__'
)(),
'abs'
:
lambda
y
:
getattr
(
y
,
'__abs__'
)(),
'real'
:
lambda
y
:
getattr
(
y
,
'real'
),
'imag'
:
lambda
y
:
getattr
(
y
,
'imag'
),
'nanmin'
:
lambda
y
:
getattr
(
y
,
'nanmin'
)(
axis
=
axis
),
'amin'
:
lambda
y
:
getattr
(
y
,
'amin'
)(
axis
=
axis
),
'nanmax'
:
lambda
y
:
getattr
(
y
,
'nanmax'
)(
axis
=
axis
),
'amax'
:
lambda
y
:
getattr
(
y
,
'amax'
)(
axis
=
axis
),
'median'
:
lambda
y
:
getattr
(
y
,
'median'
)(
axis
=
axis
),
'mean'
:
lambda
y
:
getattr
(
y
,
'mean'
)(
axis
=
axis
),
'std'
:
lambda
y
:
getattr
(
y
,
'std'
)(
axis
=
axis
),
'var'
:
lambda
y
:
getattr
(
y
,
'var'
)(
axis
=
axis
),
'argmin_nonflat'
:
lambda
y
:
getattr
(
y
,
'argmin_nonflat'
)(
axis
=
axis
),
'argmin'
:
lambda
y
:
getattr
(
y
,
'argmin'
)(
axis
=
axis
),
'argmax_nonflat'
:
lambda
y
:
getattr
(
y
,
'argmax_nonflat'
)(
axis
=
axis
),
'argmax'
:
lambda
y
:
getattr
(
y
,
'argmax'
)(
axis
=
axis
),
'conjugate'
:
lambda
y
:
getattr
(
y
,
'conjugate'
)(),
'sum'
:
lambda
y
:
getattr
(
y
,
'sum'
)(
axis
=
axis
),
'prod'
:
lambda
y
:
getattr
(
y
,
'prod'
)(
axis
=
axis
),
'unique'
:
lambda
y
:
getattr
(
y
,
'unique'
)(),
'copy'
:
lambda
y
:
getattr
(
y
,
'copy'
)(),
'copy_empty'
:
lambda
y
:
getattr
(
y
,
'copy_empty'
)(),
'isnan'
:
lambda
y
:
getattr
(
y
,
'isnan'
)(),
'isinf'
:
lambda
y
:
getattr
(
y
,
'isinf'
)(),
'isfinite'
:
lambda
y
:
getattr
(
y
,
'isfinite'
)(),
'nan_to_num'
:
lambda
y
:
getattr
(
y
,
'nan_to_num'
)(),
'all'
:
lambda
y
:
getattr
(
y
,
'all'
)(
axis
=
axis
),
'any'
:
lambda
y
:
getattr
(
y
,
'any'
)(
axis
=
axis
),
'None'
:
lambda
y
:
y
}
return
translation
[
op
](
x
,
**
kwargs
)
def
binary_operation
(
self
,
x
,
y
,
op
=
'None'
,
cast
=
0
):
translation
=
{
'add'
:
lambda
z
:
getattr
(
z
,
'__add__'
),
'radd'
:
lambda
z
:
getattr
(
z
,
'__radd__'
),
'iadd'
:
lambda
z
:
getattr
(
z
,
'__iadd__'
),
'sub'
:
lambda
z
:
getattr
(
z
,
'__sub__'
),
'rsub'
:
lambda
z
:
getattr
(
z
,
'__rsub__'
),
'isub'
:
lambda
z
:
getattr
(
z
,
'__isub__'
),
'mul'
:
lambda
z
:
getattr
(
z
,
'__mul__'
),
'rmul'
:
lambda
z
:
getattr
(
z
,
'__rmul__'
),
'imul'
:
lambda
z
:
getattr
(
z
,
'__imul__'
),
'div'
:
lambda
z
:
getattr
(
z
,
'__div__'
),
'rdiv'
:
lambda
z
:
getattr
(
z
,
'__rdiv__'
),
'idiv'
:
lambda
z
:
getattr
(
z
,
'__idiv__'
),
'pow'
:
lambda
z
:
getattr
(
z
,
'__pow__'
),
'rpow'
:
lambda
z
:
getattr
(
z
,
'__rpow__'
),
'ipow'
:
lambda
z
:
getattr
(
z
,
'__ipow__'
),
'ne'
:
lambda
z
:
getattr
(
z
,
'__ne__'
),
'lt'
:
lambda
z
:
getattr
(
z
,
'__lt__'
),
'le'
:
lambda
z
:
getattr
(
z
,
'__le__'
),
'eq'
:
lambda
z
:
getattr
(
z
,
'__eq__'
),
'ge'
:
lambda
z
:
getattr
(
z
,
'__ge__'
),
'gt'
:
lambda
z
:
getattr
(
z
,
'__gt__'
),
'None'
:
lambda
z
:
lambda
u
:
u
}
if
(
cast
&
1
)
!=
0
:
x
=
self
.
cast
(
x
)
if
(
cast
&
2
)
!=
0
:
y
=
self
.
cast
(
y
)
return
translation
[
op
](
x
)(
y
)
def
__add__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'add'
)
__radd__
=
__add__
def
__iadd__
(
self
,
other
):
...
...
@@ -1269,6 +1340,7 @@ class field(object):
def
__mul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'mul'
)
__rmul__
=
__mul__
def
__imul__
(
self
,
other
):
...
...
@@ -1282,6 +1354,7 @@ class field(object):
def
__idiv__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'idiv'
,
inplace
=
True
)
__truediv__
=
__div__
__itruediv__
=
__idiv__
...
...
@@ -1318,6 +1391,7 @@ class field(object):
def
__gt__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'gt'
)
class
EmptyField
(
field
):
def
__init__
(
self
):
pass
test/test_nifty_field.py
View file @
5db4d734
...
...
@@ -123,3 +123,31 @@ class Test_field_init(unittest.TestCase):
f
=
field
(
domain
=
(
s
,),
dtype
=
s
.
dtype
,
datamodel
=
datamodel
)
assert
(
f
.
domain
[
0
]
is
s
)
assert
(
s
.
check_codomain
(
f
.
codomain
[
0
]))
assert
(
s
.
get_shape
()
==
f
.
get_shape
())
class
Test_field_multiple_init
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
([(
1
,)],
[
True
],
[
0
],
[
None
],
[
False
],
fft_modules
,
DATAMODELS
[
'rg_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_multiple_space_init
(
self
,
shape
,
zerocenter
,
complexity
,
distances
,
harmonic
,
fft_module
,
datamodel
):
s1
=
rg_space
(
shape
=
shape
,
zerocenter
=
zerocenter
,
complexity
=
complexity
,
distances
=
distances
,
harmonic
=
harmonic
,
fft_module
=
fft_module
)
s2
=
rg_space
(
shape
=
shape
,
zerocenter
=
zerocenter
,
complexity
=
complexity
,
distances
=
distances
,
harmonic
=
harmonic
,
fft_module
=
fft_module
)
f
=
field
(
domain
=
(
s1
,
s2
),
dtype
=
s1
.
dtype
,
datamodel
=
datamodel
)
assert
(
f
.
domain
[
0
]
is
s1
)
assert
(
f
.
domain
[
1
]
is
s2
)
assert
(
s1
.
check_codomain
(
f
.
codomain
[
0
]))
assert
(
s2
.
check_codomain
(
f
.
codomain
[
1
]))
assert
(
s1
.
get_shape
()
+
s2
.
get_shape
()
==
f
.
get_shape
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment