Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
4424bdee
Commit
4424bdee
authored
Jul 23, 2016
by
theos
Browse files
Added inplace parameter to Space.weight.
parent
39b1cc50
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
4424bdee
...
...
@@ -220,7 +220,10 @@ class Field(object):
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
_infer_dtype
(
self
,
domain
=
None
,
dtype
=
None
,
field_type
=
None
):
dtype_tuple
=
(
np
.
dtype
(
gc
[
'default_field_dtype'
]),)
if
dtype
is
None
:
dtype_tuple
=
(
np
.
dtype
(
gc
[
'default_field_dtype'
]),)
else
:
dtype_tuple
=
(
np
.
dtype
(
dtype
))
if
domain
is
not
None
:
dtype_tuple
+=
tuple
(
np
.
dtype
(
sp
.
dtype
)
for
sp
in
domain
)
if
field_type
is
not
None
:
...
...
@@ -331,6 +334,7 @@ class Field(object):
def
copy
(
self
,
domain
=
None
,
codomain
=
None
,
field_type
=
None
,
**
kwargs
):
copied_val
=
self
.
_unary_operation
(
self
.
get_val
(),
op
=
'copy'
,
**
kwargs
)
# TODO: respect distribution_strategy
new_field
=
self
.
copy_empty
(
domain
=
domain
,
codomain
=
codomain
,
field_type
=
field_type
)
...
...
@@ -391,6 +395,7 @@ class Field(object):
**
kwargs
)
return
new_field
# TODO: use property for val
def
set_val
(
self
,
new_val
=
None
,
copy
=
False
):
"""
Resets the field values.
...
...
@@ -431,6 +436,7 @@ class Field(object):
return
global_shape
# use space.dim and field_type.dim
@
property
def
dim
(
self
):
"""
...
...
@@ -512,6 +518,7 @@ class Field(object):
shape
=
self
.
shape
# Case 1: x is a distributed_data_object
# TODO: Use d2o casting for this case directly, too.
if
isinstance
(
x
,
distributed_data_object
):
if
x
.
comm
is
not
self
.
_comm
:
raise
ValueError
(
about
.
_errors
.
cstring
(
...
...
@@ -608,9 +615,10 @@ class Field(object):
spaces
=
range
(
len
(
self
.
shape
))
for
ind
,
sp
in
enumerate
(
self
.
domain
):
new_val
=
sp
.
calc_weight
(
new_val
,
power
=
power
,
axes
=
self
.
domain_axes
[
ind
])
new_val
=
sp
.
weight
(
new_val
,
power
=
power
,
axes
=
self
.
domain_axes
[
ind
],
inplace
=
inplace
)
new_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
new_field
...
...
@@ -1164,6 +1172,7 @@ class Field(object):
return
self
.
_unary_operation
(
self
.
get_val
(),
op
=
'var'
,
**
kwargs
)
# TODO: replace `split` by `def argmin_nonflat`
def
argmin
(
self
,
split
=
False
,
**
kwargs
):
"""
Returns the index of the minimum field value.
...
...
@@ -1348,7 +1357,8 @@ class Field(object):
def
__add__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'add'
)
__radd__
=
__add__
def
__radd__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'radd'
)
def
__iadd__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'iadd'
,
inplace
=
True
)
...
...
@@ -1365,7 +1375,8 @@ class Field(object):
def
__mul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'mul'
)
__rmul__
=
__mul__
def
__rmul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'rmul'
)
def
__imul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'imul'
,
inplace
=
True
)
...
...
@@ -1379,9 +1390,6 @@ class Field(object):
def
__idiv__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'idiv'
,
inplace
=
True
)
__truediv__
=
__div__
__itruediv__
=
__idiv__
def
__pow__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'pow'
)
...
...
nifty/spaces/power_space/power_space.py
View file @
4424bdee
...
...
@@ -35,7 +35,7 @@ class PowerSpace(Space):
# every power-pixel has a volume of 1
return
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
paradict
[
'pindex'
].
shape
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
):
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
total_shape
=
x
.
shape
axes
=
cast_axis_to_tuple
(
axes
,
len
(
total_shape
))
...
...
@@ -49,7 +49,12 @@ class PowerSpace(Space):
weight
=
self
.
paradict
[
'rho'
].
reshape
(
reshaper
)
if
power
!=
1
:
weight
=
weight
**
power
result_x
=
x
*
weight
if
inplace
:
x
*=
weight
result_x
=
x
else
:
result_x
=
x
*
weight
return
result_x
...
...
nifty/spaces/rg_space/rg_space.py
View file @
4424bdee
...
...
@@ -175,9 +175,14 @@ class RGSpace(Space):
def
total_volume
(
self
):
return
self
.
dim
*
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
paradict
[
'distances'
])
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
):
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
weight
=
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
paradict
[
'distances'
])
**
power
return
x
*
weight
if
inplace
:
x
*=
weight
result_x
=
x
else
:
result_x
=
x
*
weight
return
result_x
def
compute_k_array
(
self
,
distribution_strategy
):
"""
...
...
nifty/spaces/space/space.py
View file @
4424bdee
...
...
@@ -262,7 +262,7 @@ class Space(object):
def
complement_cast
(
self
,
x
,
axes
=
None
):
return
x
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
):
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
"""
Weights a given array of field values with the pixel volumes (not
the meta volumes) to a given power.
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment