Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
0db9e4a6
Commit
0db9e4a6
authored
Nov 10, 2017
by
Martin Reinecke
Browse files
more fixes
parent
3f5859d0
Pipeline
#21362
failed with stage
in 3 minutes and 54 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/operators/power_projection_operator.py
View file @
0db9e4a6
...
...
@@ -50,12 +50,18 @@ class PowerProjectionOperator(LinearOperator):
tgt
[
self
.
_space
]
=
power_space
self
.
_target
=
DomainTuple
.
make
(
tgt
)
# shopping list:
# 1) make sure that pindex is distributed in the same way as in the Field living on self.domain.
# 2) if the operated-on space is not distributed (i.e. if it is not space 0), _no_ further communication is necessary
def
_times
(
self
,
x
):
# harmonic field goes in
# pindex must be distributed in the same way as harmonic field
# power field must be available in full
pindex
=
self
.
_target
[
self
.
_space
].
pindex
res
=
Field
.
zeros
(
self
.
_target
,
dtype
=
x
.
dtype
)
if
dobj
.
distaxis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
# the distributed axis is part of the projected space
pindex
=
dobj
.
local_data
(
pindex
)
else
:
else
:
# pindex must be available fully on every task
pindex
=
dobj
.
to_global_data
(
pindex
)
pindex
.
reshape
((
1
,
pindex
.
size
,
1
))
arr
=
dobj
.
local_data
(
x
.
weight
(
1
).
val
)
...
...
@@ -64,8 +70,15 @@ class PowerProjectionOperator(LinearOperator):
presize
=
np
.
prod
(
arr
.
shape
[
0
:
firstaxis
],
dtype
=
np
.
int
)
postsize
=
np
.
prod
(
arr
.
shape
[
lastaxis
+
1
:],
dtype
=
np
.
int
)
arr
=
arr
.
reshape
((
presize
,
pindex
.
size
,
postsize
))
oarr
=
dobj
.
local_data
(
res
.
val
).
reshape
((
presize
,
-
1
,
postsize
)
)
oarr
=
np
.
zeros
((
presize
,
self
.
_target
[
self
.
_space
].
shape
[
0
],
postsize
),
dtype
=
x
.
dtype
)
np
.
add
.
at
(
oarr
,
(
slice
(
None
),
pindex
.
ravel
(),
slice
(
None
)),
arr
)
if
dobj
.
distaxis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
oarr
=
dobj
.
np_allreduce_sum
(
oarr
)
oarr
=
oarr
.
reshape
(
self
.
_target
.
shape
)
res
=
Field
(
self
.
_target
,
dobj
.
from_global_data
(
oarr
))
else
:
oarr
=
oarr
.
reshape
(
dobj
.
get_locshape
(
self
.
_target
.
shape
,
dobj
.
distaxis
(
x
.
val
)))
res
=
Field
(
self
.
_target
,
dobj
.
from_local_data
(
self
.
_target
.
shape
,
oarr
,
dobj
.
default_distaxis
()))
return
res
.
weight
(
-
1
,
spaces
=
self
.
_space
)
def
_adjoint_times
(
self
,
x
):
...
...
@@ -73,10 +86,11 @@ class PowerProjectionOperator(LinearOperator):
res
=
Field
.
empty
(
self
.
_domain
,
dtype
=
x
.
dtype
)
if
dobj
.
distaxis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
# the distributed axis is part of the projected space
pindex
=
dobj
.
local_data
(
pindex
)
arr
=
dobj
.
to_global_data
(
x
.
val
)
else
:
pindex
=
dobj
.
to_global_data
(
pindex
)
arr
=
dobj
.
local_data
(
x
.
val
)
pindex
=
pindex
.
reshape
((
1
,
pindex
.
size
,
1
))
arr
=
dobj
.
local_data
(
x
.
val
)
firstaxis
=
x
.
domain
.
axes
[
self
.
_space
][
0
]
lastaxis
=
x
.
domain
.
axes
[
self
.
_space
][
-
1
]
presize
=
np
.
prod
(
arr
.
shape
[
0
:
firstaxis
],
dtype
=
np
.
int
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment