Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
fcbd1ea9
Commit
fcbd1ea9
authored
Jul 17, 2018
by
Martin Reinecke
Browse files
Merge branch 'even_more_operator_work' into 'NIFTy_5'
Even more operator work See merge request ift/nifty-dev!59
parents
7c9500fa
2db0d555
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/field_zero_padder.py
View file @
fcbd1ea9
...
...
@@ -12,19 +12,21 @@ from .. import utilities
class
FieldZeroPadder
(
LinearOperator
):
def
__init__
(
self
,
domain
,
factor
,
space
=
0
):
def
__init__
(
self
,
domain
,
new_shape
,
space
=
0
):
super
(
FieldZeroPadder
,
self
).
__init__
()
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_space
=
utilities
.
infer_space
(
self
.
_domain
,
space
)
dom
=
self
.
_domain
[
self
.
_space
]
if
not
isinstance
(
dom
,
RGSpace
):
raise
TypeError
(
"RGSpace required"
)
if
not
len
(
dom
.
shape
)
==
1
:
raise
TypeError
(
"RGSpace must be one-dimensional"
)
if
dom
.
harmonic
:
raise
TypeError
(
"RGSpace must not be harmonic"
)
tgt
=
RGSpace
((
int
(
factor
*
dom
.
shape
[
0
]),),
dom
.
distances
)
if
len
(
new_shape
)
!=
len
(
dom
.
shape
):
raise
ValueError
(
"Shape mismatch"
)
if
any
([
a
<
b
for
a
,
b
in
zip
(
new_shape
,
dom
.
shape
)]):
raise
ValueError
(
"New shape must be larger than old shape"
)
tgt
=
RGSpace
(
new_shape
,
dom
.
distances
)
self
.
_target
=
list
(
self
.
_domain
)
self
.
_target
[
self
.
_space
]
=
tgt
self
.
_target
=
DomainTuple
.
make
(
self
.
_target
)
...
...
@@ -47,20 +49,21 @@ class FieldZeroPadder(LinearOperator):
dax
=
dobj
.
distaxis
(
x
)
shp_in
=
x
.
shape
shp_out
=
self
.
_tgt
(
mode
).
shape
ax
=
self
.
_target
.
axes
[
self
.
_space
][
0
]
if
dax
==
ax
:
x
=
dobj
.
redistribute
(
x
,
nodist
=
(
ax
,))
axbefore
=
self
.
_target
.
axes
[
self
.
_space
][
0
]
axes
=
self
.
_target
.
axes
[
self
.
_space
]
if
dax
in
axes
:
x
=
dobj
.
redistribute
(
x
,
nodist
=
axes
)
curax
=
dobj
.
distaxis
(
x
)
if
mode
==
self
.
ADJOINT_TIMES
:
newarr
=
np
.
empty
(
dobj
.
local_shape
(
shp_out
,
curax
),
dtype
=
x
.
dtype
)
newarr
[()]
=
dobj
.
local_data
(
x
)[(
slice
(
None
),)
*
ax
+
(
slice
(
0
,
shp_out
[
ax
]),)
]
sl
=
tuple
(
slice
(
0
,
shp_out
[
axis
])
for
axis
in
axes
)
newarr
[()]
=
dobj
.
local_data
(
x
)[(
slice
(
None
),)
*
axbefore
+
sl
]
else
:
newarr
=
np
.
zeros
(
dobj
.
local_shape
(
shp_out
,
curax
),
dtype
=
x
.
dtype
)
newarr
[(
slice
(
None
),)
*
ax
+
(
slice
(
0
,
shp_in
[
ax
]),)
]
=
dobj
.
local_data
(
x
)
sl
=
tuple
(
slice
(
0
,
shp_in
[
axis
])
for
axis
in
axes
)
newarr
[(
slice
(
None
),)
*
axbefore
+
sl
]
=
dobj
.
local_data
(
x
)
newarr
=
dobj
.
from_local_data
(
shp_out
,
newarr
,
distaxis
=
curax
)
if
dax
==
ax
:
newarr
=
dobj
.
redistribute
(
newarr
,
dist
=
ax
)
if
dax
in
ax
es
:
newarr
=
dobj
.
redistribute
(
newarr
,
dist
=
d
ax
)
return
Field
(
self
.
_tgt
(
mode
),
val
=
newarr
)
nifty5/operators/qht_operator.py
View file @
fcbd1ea9
...
...
@@ -74,7 +74,6 @@ class QHTOperator(LinearOperator):
n
=
self
.
_domain
.
axes
[
self
.
_space
]
rng
=
n
if
mode
==
self
.
TIMES
else
reversed
(
n
)
ax
=
dobj
.
distaxis
(
x
)
globshape
=
x
.
shape
for
i
in
rng
:
sl
=
(
slice
(
None
),)
*
i
+
(
slice
(
1
,
None
),)
if
i
==
ax
:
...
...
nifty5/operators/symmetrizing_operator.py
View file @
fcbd1ea9
...
...
@@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator):
self
.
_check_input
(
x
,
mode
)
tmp
=
x
.
val
.
copy
()
ax
=
dobj
.
distaxis
(
tmp
)
globshape
=
tmp
.
shape
for
i
in
self
.
_domain
.
axes
[
self
.
_space
]:
lead
=
(
slice
(
None
),)
*
i
if
i
==
ax
:
...
...
test/test_operators/test_adjoint.py
View file @
fcbd1ea9
...
...
@@ -119,9 +119,10 @@ class Consistency_Tests(unittest.TestCase):
@
expand
(
product
([
0
,
2
],
[
2
,
2.7
],
[
np
.
float64
,
np
.
complex128
]))
def
testZeroPadder
(
self
,
space
,
factor
,
dtype
):
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
RGSpace
(
7
),
dom
=
(
ift
.
RGSpace
(
10
),
ift
.
UnstructuredDomain
(
13
),
ift
.
RGSpace
(
7
,
12
),
ift
.
HPSpace
(
4
))
op
=
ift
.
FieldZeroPadder
(
dom
,
factor
,
space
)
newshape
=
[
factor
*
l
for
l
in
dom
[
space
].
shape
]
op
=
ift
.
FieldZeroPadder
(
dom
,
newshape
,
space
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([(
ift
.
RGSpace
(
10
,
harmonic
=
True
),
4
,
0
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment