Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
397f5fc7
Commit
397f5fc7
authored
Jul 19, 2018
by
Julia Stadler
Browse files
added central_zero_padder.py
parent
1cd93ef6
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/central_zero_padder.py
0 → 100644
View file @
397f5fc7
import
numpy
as
np
import
itertools
from
..
import
utilities
from
.linear_operator
import
LinearOperator
from
..domain_tuple
import
DomainTuple
from
..domains.rg_space
import
RGSpace
from
..field
import
Field
class
CentralZeroPadder
(
LinearOperator
):
def
__init__
(
self
,
domain
,
new_shape
,
space
=
0
):
super
(
CentralZeroPadder
,
self
).
__init__
()
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
dom
=
self
.
_domain
[
self
.
_space
]
if
not
isinstance
(
dom
,
RGSpace
):
raise
TypeError
(
"RGSpace required"
)
if
dom
.
harmonic
:
raise
TypeError
(
"RGSpace must not be harmonic"
)
if
len
(
new_shape
)
!=
len
(
dom
.
shape
):
raise
ValueError
(
"Shape missmatch"
)
if
any
(
[
a
<
b
for
a
,
b
in
zip
(
new_shape
,
dom
.
shape
)]):
raise
ValueError
(
"New shape must be larger than old shape"
)
tgt
=
RGSpace
(
new_shape
,
dom
.
distances
)
self
.
_target
=
list
(
self
.
_domain
)
self
.
_target
[
self
.
_space
]
=
tgt
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
slicer
=
[]
axes
=
self
.
_target
.
axes
[
self
.
_space
]
for
i
in
range
(
len
(
self
.
_domain
.
shape
)):
if
i
in
axes
:
slicer_fw
=
slice
(
0
,
self
.
_domain
.
shape
[
i
]
/
2
)
slicer_bw
=
slice
(
-
self
.
_domain
.
shape
[
i
]
/
2
,
None
)
slicer
.
append
(
[
slicer_fw
,
slicer_bw
]
)
self
.
slicer
=
list
(
itertools
.
product
(
*
slicer
))
for
i
in
range
(
len
(
self
.
slicer
)):
for
j
in
range
(
len
(
self
.
_domain
.
shape
)):
if
not
j
in
axes
:
tmp
=
(
list
(
self
.
slicer
[
i
]))
tmp
.
insert
(
j
,
slice
(
None
))
self
.
slicer
[
i
]
=
tmp
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_target
@
property
def
capability
(
self
):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
x
=
x
.
val
if
mode
==
self
.
TIMES
:
y
=
np
.
zeros
(
self
.
_target
.
shape
)
for
i
in
self
.
slicer
:
y
[
i
]
=
x
[
i
]
return
Field
(
self
.
_target
,
val
=
y
)
if
mode
==
self
.
ADJOINT_TIMES
:
y
=
np
.
zeros
(
self
.
_domain
.
shape
)
for
i
in
self
.
slicer
:
y
[
i
]
=
x
[
i
]
return
Field
(
self
.
_domain
,
val
=
y
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment