Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
f9ce2a6e
Commit
f9ce2a6e
authored
Sep 02, 2017
by
Martin Reinecke
Browse files
great Field revamp, next part
parent
77ee7dcc
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
f9ce2a6e
...
...
@@ -112,15 +112,13 @@ class Field(object):
"DomainObject instance."
)
return
domain
def
_get_axes_tuple
(
self
,
things_with_shape
,
start
=
0
):
i
=
start
def
_get_axes_tuple
(
self
,
things_with_shape
):
i
=
0
axes_list
=
[]
for
thing
in
things_with_shape
:
l
=
[]
for
j
in
range
(
len
(
thing
.
shape
)):
l
+=
[
i
]
i
+=
1
axes_list
+=
[
tuple
(
l
)]
nax
=
len
(
thing
.
shape
)
axes_list
+=
[
tuple
(
range
(
i
,
i
+
nax
))]
i
+=
nax
return
tuple
(
axes_list
)
def
_infer_dtype
(
self
,
dtype
,
val
):
...
...
@@ -179,7 +177,7 @@ class Field(object):
sample
=
f
.
get_val
(
copy
=
False
)
generator_function
=
getattr
(
Random
,
random_type
)
sample
[
:
]
=
generator_function
(
dtype
=
f
.
dtype
,
sample
[
()
]
=
generator_function
(
dtype
=
f
.
dtype
,
shape
=
sample
.
shape
,
**
kwargs
)
return
f
...
...
@@ -344,7 +342,7 @@ class Field(object):
local_data
=
pindex
semiscaled_local_data
=
local_data
.
reshape
(
semiscaled_local_shape
)
result_obj
=
np
.
empty
(
target_shape
,
dtype
=
pindex
.
dtype
)
result_obj
[
:
]
=
semiscaled_local_data
result_obj
[
()
]
=
semiscaled_local_data
return
result_obj
...
...
@@ -494,23 +492,13 @@ class Field(object):
def
_spec_to_rescaler
(
self
,
spec
,
result_list
,
power_space_index
):
power_space
=
self
.
domain
[
power_space_index
]
# weight the random fields with the power spectrum
# therefore get the pindex from the power space
pindex
=
power_space
.
pindex
# Now use numpy advanced indexing in order to put the entries of the
# power spectrum into the appropriate places of the pindex array.
# Do this for every 'pindex-slice' in parallel using the 'slice(None)'s
local_pindex
=
pindex
local_blow_up
=
[
slice
(
None
)]
*
len
(
spec
.
shape
)
# it is important to count from behind, since spec potentially grows
# with every iteration
index
=
self
.
domain_axes
[
power_space_index
][
0
]
-
len
(
self
.
shape
)
local_blow_up
[
index
]
=
local_
pindex
local_blow_up
[
index
]
=
power_space
.
pindex
# here, the power_spectrum is distributed into the new shape
local_rescaler
=
spec
[
local_blow_up
]
return
local_rescaler
return
spec
[
local_blow_up
]
# ---Properties---
...
...
@@ -668,19 +656,13 @@ class Field(object):
def
real
(
self
):
""" The real part of the field (data is not copied).
"""
real_part
=
self
.
val
.
real
result
=
self
.
copy_empty
(
dtype
=
real_part
.
dtype
)
result
.
set_val
(
new_val
=
real_part
,
copy
=
False
)
return
result
return
Field
(
self
.
domain
,
self
.
val
.
real
)
@
property
def
imag
(
self
):
""" The imaginary part of the field (data is not copied).
"""
real_part
=
self
.
val
.
imag
result
=
self
.
copy_empty
(
dtype
=
real_part
.
dtype
)
result
.
set_val
(
new_val
=
real_part
,
copy
=
False
)
return
result
return
Field
(
self
.
domain
,
self
.
val
.
imag
)
# ---Special unary/binary operations---
...
...
@@ -711,12 +693,9 @@ class Field(object):
"""
copied_val
=
self
.
get_val
(
copy
=
True
)
new_field
=
self
.
copy_empty
(
domain
=
domain
,
dtype
=
dtype
)
new_field
.
set_val
(
new_val
=
copied_val
,
copy
=
False
)
return
new_field
if
domain
is
None
:
domain
=
self
.
domain
return
Field
(
domain
=
domain
,
val
=
self
.
_val
,
dtype
=
dtype
,
copy
=
True
)
def
copy_empty
(
self
,
domain
=
None
,
dtype
=
None
):
""" Returns an empty copy of the Field.
...
...
@@ -748,41 +727,9 @@ class Field(object):
if
domain
is
None
:
domain
=
self
.
domain
else
:
domain
=
self
.
_parse_domain
(
domain
)
if
dtype
is
None
:
dtype
=
self
.
dtype
else
:
dtype
=
np
.
dtype
(
dtype
)
fast_copyable
=
True
try
:
for
i
in
range
(
len
(
self
.
domain
)):
if
self
.
domain
[
i
]
is
not
domain
[
i
]:
fast_copyable
=
False
break
except
IndexError
:
fast_copyable
=
False
if
(
fast_copyable
and
dtype
==
self
.
dtype
):
new_field
=
self
.
_fast_copy_empty
()
else
:
new_field
=
Field
(
domain
=
domain
,
dtype
=
dtype
)
return
new_field
def
_fast_copy_empty
(
self
):
# make an empty field
new_field
=
EmptyField
()
# repair its class
new_field
.
__class__
=
self
.
__class__
# copy domain, codomain and val
for
key
,
value
in
list
(
self
.
__dict__
.
items
()):
if
key
!=
'_val'
:
new_field
.
__dict__
[
key
]
=
value
else
:
new_field
.
__dict__
[
key
]
=
np
.
empty_like
(
self
.
val
)
return
new_field
return
Field
(
domain
=
domain
,
dtype
=
dtype
)
def
weight
(
self
,
power
=
1
,
inplace
=
False
,
spaces
=
None
):
""" Weights the pixels of `self` with their invidual pixel-volume.
...
...
@@ -805,10 +752,7 @@ class Field(object):
The weighted field.
"""
if
inplace
:
new_field
=
self
else
:
new_field
=
self
.
copy_empty
()
new_field
=
self
if
inplace
else
self
.
copy_empty
()
new_val
=
self
.
get_val
(
copy
=
False
)
...
...
@@ -851,16 +795,10 @@ class Field(object):
"the NIFTy field class"
)
# Compute the dot respecting the fact of discrete/continuous spaces
if
bare
:
y
=
self
else
:
y
=
self
.
weight
(
power
=
1
)
y
=
self
if
bare
else
self
.
weight
(
power
=
1
)
if
spaces
is
None
:
x_val
=
x
.
get_val
(
copy
=
False
)
y_val
=
y
.
get_val
(
copy
=
False
)
result
=
(
y_val
.
conjugate
()
*
x_val
).
sum
()
return
result
return
np
.
vdot
(
y
.
val
.
flatten
(),
x
.
val
.
flatten
())
else
:
# create a diagonal operator which is capable of taking care of the
# axes-matching
...
...
@@ -899,57 +837,27 @@ class Field(object):
"""
if
inplace
:
work_field
=
self
self
.
imag
*=-
1
return
self
else
:
work_field
=
self
.
copy_empty
()
new_val
=
self
.
get_val
(
copy
=
False
)
new_val
=
new_val
.
conjugate
()
work_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
work_field
return
Field
(
self
.
domain
,
np
.
conj
(
self
.
val
),
self
.
dtype
)
# ---General unary/contraction methods---
def
__pos__
(
self
):
""" x.__pos__() <==> +x
Returns a (positive) copy of `self`.
"""
return
self
.
copy
()
def
__neg__
(
self
):
""" x.__neg__() <==> -x
Returns a negative copy of `self`.
"""
return_field
=
self
.
copy_empty
()
new_val
=
-
self
.
get_val
(
copy
=
False
)
return_field
.
set_val
(
new_val
,
copy
=
False
)
return
return_field
return
Field
(
self
.
domain
,
-
self
.
val
,
self
.
dtype
)
def
__abs__
(
self
):
""" x.__abs__() <==> abs(x)
Returns an absolute valued copy of `self`.
"""
new_val
=
abs
(
self
.
get_val
(
copy
=
False
))
return_field
=
self
.
copy_empty
(
dtype
=
new_val
.
dtype
)
return_field
.
set_val
(
new_val
,
copy
=
False
)
return
return_field
return
Field
(
self
.
domain
,
np
.
abs
(
self
.
val
),
self
.
dtype
)
def
_contraction_helper
(
self
,
op
,
spaces
):
# build a list of all axes
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
)
)
else
:
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
return
getattr
(
self
.
val
,
op
)(
)
# build a list of all axes
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
axes_list
=
tuple
(
self
.
domain_axes
[
sp_index
]
for
sp_index
in
spaces
)
...
...
@@ -1010,28 +918,14 @@ class Field(object):
# ---General binary methods---
def
_binary_helper
(
self
,
other
,
op
,
inplace
=
False
):
def
_binary_helper
(
self
,
other
,
op
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
try
:
assert
len
(
other
.
domain
)
==
len
(
self
.
domain
)
for
index
in
range
(
len
(
self
.
domain
)):
assert
other
.
domain
[
index
]
==
self
.
domain
[
index
]
except
AssertionError
:
raise
ValueError
(
"domains are incompatible."
)
other
=
other
.
get_val
(
copy
=
False
)
if
other
.
domain
!=
self
.
domain
:
raise
ValueError
(
"domains are incompatible."
)
return
Field
(
self
.
domain
,
getattr
(
self
.
val
,
op
)(
other
.
val
))
self_val
=
self
.
get_val
(
copy
=
False
)
return_val
=
getattr
(
self_val
,
op
)(
other
)
if
inplace
:
working_field
=
self
else
:
working_field
=
self
.
copy_empty
(
dtype
=
return_val
.
dtype
)
working_field
.
set_val
(
return_val
,
copy
=
False
)
return
working_field
return
Field
(
self
.
domain
,
getattr
(
self
.
val
,
op
)(
other
))
def
__add__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__add__'
)
...
...
@@ -1040,7 +934,7 @@ class Field(object):
return
self
.
_binary_helper
(
other
,
op
=
'__radd__'
)
def
__iadd__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__iadd__'
,
inplace
=
True
)
return
self
.
_binary_helper
(
other
,
op
=
'__iadd__'
)
def
__sub__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__sub__'
)
...
...
@@ -1049,7 +943,7 @@ class Field(object):
return
self
.
_binary_helper
(
other
,
op
=
'__rsub__'
)
def
__isub__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__isub__'
,
inplace
=
True
)
return
self
.
_binary_helper
(
other
,
op
=
'__isub__'
)
def
__mul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__mul__'
)
...
...
@@ -1058,7 +952,7 @@ class Field(object):
return
self
.
_binary_helper
(
other
,
op
=
'__rmul__'
)
def
__imul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__imul__'
,
inplace
=
True
)
return
self
.
_binary_helper
(
other
,
op
=
'__imul__'
)
def
__div__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__div__'
)
...
...
@@ -1073,7 +967,7 @@ class Field(object):
return
self
.
_binary_helper
(
other
,
op
=
'__rtruediv__'
)
def
__idiv__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__idiv__'
,
inplace
=
True
)
return
self
.
_binary_helper
(
other
,
op
=
'__idiv__'
)
def
__pow__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__pow__'
)
...
...
@@ -1082,7 +976,7 @@ class Field(object):
return
self
.
_binary_helper
(
other
,
op
=
'__rpow__'
)
def
__ipow__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__ipow__'
,
inplace
=
True
)
return
self
.
_binary_helper
(
other
,
op
=
'__ipow__'
)
def
__lt__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__lt__'
)
...
...
@@ -1119,8 +1013,3 @@ class Field(object):
"
\n
- val = "
+
repr
(
self
.
get_val
())
+
\
"
\n
- min.,max. = "
+
str
(
minmax
)
+
\
"
\n
- mean = "
+
str
(
mean
)
class
EmptyField
(
Field
):
def
__init__
(
self
):
pass
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment