Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
ift
NIFTy
Commits
c56d29a4
Commit
c56d29a4
authored
Mar 26, 2020
by
Lukas Platz
Browse files
extend MatrixProductOperator for multi-dim fields
parent
193a276f
Pipeline
#71502
passed with stages
in 17 minutes and 6 seconds
Changes
2
Pipelines
1
Show whitespace changes
Inline
Side-by-side
nifty6/operators/simple_linear_operators.py
View file @
c56d29a4
...
...
@@ -22,6 +22,7 @@ from ..multi_domain import MultiDomain
from
..multi_field
import
MultiField
from
.endomorphic_operator
import
EndomorphicOperator
from
.linear_operator
import
LinearOperator
import
numpy
as
np
class
VdotOperator
(
LinearOperator
):
...
...
@@ -360,21 +361,41 @@ class MatrixProductOperator(EndomorphicOperator):
matrix: scipy.sparse matrix or numpy array
Matrix of shape `(domain.shape, domain.shape)`. Needs to support
`dot()` and `transpose()` in the style of numpy arrays.
axis: integer or None
in case of multi-dim input fields (N > 1), along which axis
of the input field to apply the matrix
"""
def
__init__
(
self
,
domain
,
matrix
):
def
__init__
(
self
,
domain
,
matrix
,
axis
=
None
):
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_domain
=
DomainTuple
.
make
(
domain
)
shp
=
self
.
_domain
.
shape
if
len
(
shp
)
>
1
:
raise
TypeError
(
'Only 1D-domain supported yet.'
)
if
matrix
.
shape
!=
(
*
shp
,
*
shp
):
raise
ValueError
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
if
axis
is
None
:
raise
ValueError
(
"For multi-dim inputs an axis needs to be specified."
)
ref_shp
=
(
shp
[
axis
],
shp
[
axis
])
else
:
if
not
(
axis
is
None
or
axis
==
0
):
raise
ValueError
(
"For one-dim inputs axis must be None or zero"
)
ref_shp
=
(
shp
[
0
],
shp
[
0
])
axis
=
None
if
matrix
.
shape
!=
ref_shp
:
raise
ValueError
(
"Domain/domain on axis and matrix shape do not match."
)
self
.
_mat
=
matrix
self
.
_mat_tr
=
matrix
.
transpose
().
conjugate
()
self
.
_axis
=
axis
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
res
=
x
.
val
f
=
self
.
_mat
.
dot
if
mode
==
self
.
TIMES
else
self
.
_mat_tr
.
dot
res
=
f
(
res
)
m
=
self
.
_mat
if
mode
==
self
.
TIMES
else
self
.
_mat_tr
if
self
.
_axis
is
None
:
res
=
m
.
dot
(
x
.
val
)
else
:
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
-
1
,
self
.
_axis
))
res
=
np
.
moveaxis
(
res
,
0
,
self
.
_axis
)
return
Field
(
self
.
_domain
,
res
)
test/test_operators/test_adjoint.py
View file @
c56d29a4
...
...
@@ -280,7 +280,7 @@ def testSpecialSum(sp):
@
pmp
(
'sp'
,
[
ift
.
RGSpace
(
10
)])
@
pmp
(
'seed'
,
[
12
,
3
])
def
testMatrixProductOperator
(
sp
,
seed
):
def
testMatrixProductOperator
_1d
(
sp
,
seed
):
ift
.
random
.
push_sseq_from_seed
(
seed
)
mat
=
ift
.
random
.
current_rng
().
standard_normal
((
*
sp
.
shape
,
*
sp
.
shape
))
op
=
ift
.
MatrixProductOperator
(
sp
,
mat
)
...
...
@@ -291,6 +291,21 @@ def testMatrixProductOperator(sp, seed):
ift
.
random
.
pop_sseq
()
@
pmp
(
'sp'
,
[
ift
.
RGSpace
((
2
,
10
))])
@
pmp
(
'axis'
,
[
0
,
1
])
@
pmp
(
'seed'
,
[
12
,
3
])
def
testMatrixProductOperator_2d
(
sp
,
axis
,
seed
):
mat_shp
=
(
sp
.
shape
[
axis
],
sp
.
shape
[
axis
])
ift
.
random
.
push_sseq_from_seed
(
seed
)
mat
=
ift
.
random
.
current_rng
().
standard_normal
(
mat_shp
)
op
=
ift
.
MatrixProductOperator
(
sp
,
mat
,
axis
)
ift
.
extra
.
consistency_check
(
op
)
mat
=
mat
+
1j
*
ift
.
random
.
current_rng
().
standard_normal
(
mat_shp
)
op
=
ift
.
MatrixProductOperator
(
sp
,
mat
,
axis
)
ift
.
extra
.
consistency_check
(
op
)
ift
.
random
.
pop_sseq
()
@
pmp
(
'seed'
,
[
12
,
3
])
def
testPartialExtractor
(
seed
):
ift
.
random
.
push_sseq_from_seed
(
seed
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment