Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
8648a6b4
Commit
8648a6b4
authored
Jul 14, 2018
by
Martin Reinecke
Browse files
generalize FieldZeroPadder
parent
af1d849c
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/field_zero_padder.py
View file @
8648a6b4
...
...
@@ -19,12 +19,11 @@ class FieldZeroPadder(LinearOperator):
dom
=
self
.
_domain
[
self
.
_space
]
if
not
isinstance
(
dom
,
RGSpace
):
raise
TypeError
(
"RGSpace required"
)
if
not
len
(
dom
.
shape
)
==
1
:
raise
TypeError
(
"RGSpace must be one-dimensional"
)
if
dom
.
harmonic
:
raise
TypeError
(
"RGSpace must not be harmonic"
)
tgt
=
RGSpace
((
int
(
factor
*
dom
.
shape
[
0
]),),
dom
.
distances
)
newshp
=
tuple
(
factor
*
s
for
s
in
dom
.
shape
)
tgt
=
RGSpace
(
newshp
,
dom
.
distances
)
self
.
_target
=
list
(
self
.
_domain
)
self
.
_target
[
self
.
_space
]
=
tgt
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
...
...
@@ -47,20 +46,21 @@ class FieldZeroPadder(LinearOperator):
dax
=
dobj
.
distaxis
(
x
)
shp_in
=
x
.
shape
shp_out
=
self
.
_tgt
(
mode
).
shape
ax
=
self
.
_target
.
axes
[
self
.
_space
][
0
]
if
dax
==
ax
:
x
=
dobj
.
redistribute
(
x
,
nodist
=
(
ax
,))
axbefore
=
self
.
_target
.
axes
[
self
.
_space
][
0
]
axes
=
self
.
_target
.
axes
[
self
.
_space
]
if
dax
in
axes
:
x
=
dobj
.
redistribute
(
x
,
nodist
=
axes
)
curax
=
dobj
.
distaxis
(
x
)
if
mode
==
self
.
ADJOINT_TIMES
:
newarr
=
np
.
empty
(
dobj
.
local_shape
(
shp_out
,
curax
),
dtype
=
x
.
dtype
)
newarr
[()]
=
dobj
.
local_data
(
x
)[(
slice
(
None
),)
*
ax
+
(
slice
(
0
,
shp_out
[
ax
]),)
]
sl
=
tuple
(
slice
(
0
,
shp_out
[
axis
])
for
axis
in
axes
)
newarr
[()]
=
dobj
.
local_data
(
x
)[(
slice
(
None
),)
*
axbefore
+
sl
]
else
:
newarr
=
np
.
zeros
(
dobj
.
local_shape
(
shp_out
,
curax
),
dtype
=
x
.
dtype
)
newarr
[(
slice
(
None
),)
*
ax
+
(
slice
(
0
,
shp_in
[
ax
]),)
]
=
dobj
.
local_data
(
x
)
sl
=
tuple
(
slice
(
0
,
shp_in
[
axis
])
for
axis
in
axes
)
newarr
[(
slice
(
None
),)
*
axbefore
+
sl
]
=
dobj
.
local_data
(
x
)
newarr
=
dobj
.
from_local_data
(
shp_out
,
newarr
,
distaxis
=
curax
)
if
dax
==
ax
:
newarr
=
dobj
.
redistribute
(
newarr
,
dist
=
ax
)
if
dax
in
ax
es
:
newarr
=
dobj
.
redistribute
(
newarr
,
dist
=
d
ax
)
return
Field
(
self
.
_tgt
(
mode
),
val
=
newarr
)
nifty5/operators/qht_operator.py
View file @
8648a6b4
...
...
@@ -80,7 +80,6 @@ class QHTOperator(LinearOperator):
n
=
self
.
_domain
.
axes
[
self
.
_space
]
rng
=
n
if
mode
==
self
.
TIMES
else
reversed
(
n
)
ax
=
dobj
.
distaxis
(
x
)
globshape
=
x
.
shape
for
i
in
rng
:
sl
=
(
slice
(
None
),)
*
i
+
(
slice
(
1
,
None
),)
if
i
==
ax
:
...
...
nifty5/operators/symmetrizing_operator.py
View file @
8648a6b4
...
...
@@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator):
self
.
_check_input
(
x
,
mode
)
tmp
=
x
.
val
.
copy
()
ax
=
dobj
.
distaxis
(
tmp
)
globshape
=
tmp
.
shape
for
i
in
self
.
_domain
.
axes
[
self
.
_space
]:
lead
=
(
slice
(
None
),)
*
i
if
i
==
ax
:
...
...
test/test_operators/test_adjoint.py
View file @
8648a6b4
...
...
@@ -119,7 +119,7 @@ class Consistency_Tests(unittest.TestCase):
@
expand
(
product
([
0
,
2
],
[
2
,
2.7
],
[
np
.
float64
,
np
.
complex128
]))
def
testZeroPadder
(
self
,
space
,
factor
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
RGSpace
(
7
),
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
RGSpace
(
7
,
12
),
ift
.
HPSpace
(
4
))
op
=
ift
.
FieldZeroPadder
(
dom
,
factor
,
space
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment