Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
a5a73894
Commit
a5a73894
authored
Jul 26, 2021
by
Neel Shah
Browse files
replace flatten() by ravel(), and minor changes
parent
6b285039
Pipeline
#106469
canceled with stages
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/matrix_product_operator.py
View file @
a5a73894
...
...
@@ -61,7 +61,7 @@ class MatrixProductOperator(LinearOperator):
non-participating subspaces of the domain cannot be different in the
target space, thus the matrix must have enough unsummed axes to stand
in the places of summed-over axes of the domain, if those summed-over
axes are followed by any unsummed(inactive) axes. Example to make
axes are followed by any unsummed
(inactive) axes. Example to make
this clear: If the first 2 spaces of the domain are summed over in
the matrix multiplication and the 3rd space doesn't participate, the
matrix must have (at least) 2 axes that don't participate in the
...
...
@@ -105,9 +105,8 @@ class MatrixProductOperator(LinearOperator):
domain_dim
=
len
(
domain_shape
)
# take shortcut for trivial case
if
spaces
is
not
None
:
if
len
(
self
.
_domain
.
shape
)
==
1
and
spaces
==
(
0
,
):
spaces
=
None
if
spaces
is
not
None
and
len
(
self
.
_domain
.
shape
)
==
1
and
spaces
==
(
0
,
):
spaces
=
None
if
spaces
is
None
:
self
.
_spaces
=
None
...
...
@@ -157,7 +156,7 @@ class MatrixProductOperator(LinearOperator):
target_space_shape
.
append
(
self
.
_domain
[
i
].
shape
[
j
])
else
:
target_space_shape
.
append
(
matrix
.
shape
[
matrix_shape_idx
])
matrix_shape_idx
+=
1
matrix_shape_idx
+=
1
target_space_shape
=
tuple
(
target_space_shape
)
self
.
_mat_last_n
=
tuple
([
-
domain_dim
+
i
for
i
in
range
(
domain_dim
)])
...
...
@@ -216,14 +215,14 @@ class MatrixProductOperator(LinearOperator):
if
self
.
_spaces
is
None
:
if
not
self
.
_flatten
:
if
times
:
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
len
(
x
.
domain
.
shape
))
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
len
(
self
.
_
domain
.
shape
))
else
:
mat_axes
=
np
.
flip
(
self
.
_mat_last_m
)
field_axes
=
self
.
_target_axes
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
field_axes
))
res
=
res
.
transpose
()
else
:
res
=
m
.
dot
(
x
.
val
.
flatten
()).
reshape
(
self
.
_domain
.
shape
)
res
=
m
.
dot
(
x
.
val
.
ravel
()).
reshape
(
self
.
_domain
.
shape
)
return
Field
(
target
,
res
)
if
times
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment