Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
c2dd59ae
Commit
c2dd59ae
authored
Jun 29, 2021
by
Neel Shah
Browse files
Fixed bug with adjoint_times for spaces=None, minor optimization and aesthetic changes
parent
4457dfd7
Pipeline
#104575
canceled with stages
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/general_matrix_product.py
View file @
c2dd59ae
...
...
@@ -4,6 +4,7 @@ from .. import utilities
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
..domains.rg_space
import
RGSpace
from
.linear_operator
import
LinearOperator
class
GeneralMatrixProduct
(
LinearOperator
):
...
...
@@ -19,7 +20,8 @@ class GeneralMatrixProduct(LinearOperator):
scipy.sparse matrices the `flatten` keyword argument must be
set to true. This means that the input field will be flattened
before applying the matrix and reshaped to its original shape
afterwards.
afterwards. Flattening is only supported when the domain and target
are the same, and a target can't be specified if flatten=True'
Matrices are tested regarding their compatibility with the
called for application method.
...
...
@@ -80,13 +82,16 @@ class GeneralMatrixProduct(LinearOperator):
self
.
_spaces
=
None
self
.
_active_axes
=
utilities
.
my_sum
(
self
.
_domain
.
axes
)
self
.
_inactive_axes
=
()
if
flatten
:
domain_shape
=
(
utilities
.
my_product
(
domain_shape
),
)
target_space_shape
=
domain_shape
target_dim
=
len
(
target_space_shape
)
mat_inactive_axes_dim
=
mat_dim
-
len
(
domain_shape
)
if
mat_inactive_axes_dim
<
0
:
raise
ValueError
(
'
Domain
too big for
matrix.
'
)
raise
ValueError
(
"
Domain
has more dimensions than
matrix.
"
)
target_space_shape
=
matrix
.
shape
[:
mat_inactive_axes_dim
]
target_dim
=
mat_inactive_axes_dim
if
flatten
:
target_space_shape
=
(
utilities
.
my_product
(
target_space_shape
),
)
else
:
if
flatten
:
raise
ValueError
(
...
...
@@ -108,9 +113,9 @@ class GeneralMatrixProduct(LinearOperator):
domain_shape
=
tuple
(
domain_shape
)
self
.
_active_axes
=
tuple
(
active_axes
)
self
.
_inactive_axes
=
tuple
(
self
.
_inactive_axes
)
mat_inactive_axes_dim
=
len
(
matrix
.
shape
)
-
len
(
domain_shape
)
mat_inactive_axes_dim
=
mat_dim
-
len
(
domain_shape
)
if
mat_inactive_axes_dim
<
0
:
raise
ValueError
(
'
Domain
too big for
matrix.
'
)
raise
ValueError
(
"
Domain
has more dimensions than
matrix.
"
)
target_dim
=
mat_inactive_axes_dim
+
len
(
self
.
_inactive_axes
)
domain_dim
=
len
(
domain_shape
)
...
...
@@ -127,23 +132,44 @@ class GeneralMatrixProduct(LinearOperator):
self
.
_mat_last_n
=
tuple
([
-
domain_dim
+
i
for
i
in
range
(
domain_dim
)])
self
.
_mat_first_n
=
np
.
arange
(
domain_dim
)
self
.
_target_last_n
=
tuple
([
-
len
(
self
.
_inactive_axes
)
+
i
for
i
in
range
(
len
(
self
.
_inactive_axes
))])
#mat_last_m is needed for adjoint application even if spaces = None
self
.
_mat_last_m
=
tuple
([
-
mat_inactive_axes_dim
+
i
for
i
in
range
(
mat_inactive_axes_dim
)])
self
.
_target_last_n
=
tuple
([
-
len
(
self
.
_inactive_axes
)
+
i
for
i
in
range
(
len
(
self
.
_inactive_axes
))])
# mat_last_m is needed for adjoint application even if spaces=None
self
.
_mat_last_m
=
tuple
([
-
mat_inactive_axes_dim
+
i
for
i
in
range
(
mat_inactive_axes_dim
)])
self
.
_target_axes
=
tuple
(
range
(
len
(
target_space_shape
)))
if
spaces
!=
None
:
self
.
_field_axes
=
list
(
self
.
_target_axes
)
for
i
in
list
(
self
.
_target_axes
):
if
i
in
self
.
_inactive_axes
:
self
.
_field_axes
.
remove
(
i
)
self
.
_field_axes
=
tuple
(
self
.
_field_axes
)
if
target
==
None
:
if
target_dim
!=
0
:
default
_target
=
DomainTuple
.
make
(
RGSpace
(
shape
=
target_space_shape
))
if
flatten
:
self
.
_target
=
self
.
_domain
else
:
default_target
=
DomainTuple
.
make
(
None
)
self
.
_target
=
default_target
if
target_dim
!=
0
:
default_target
=
DomainTuple
.
make
(
RGSpace
(
shape
=
target_space_shape
))
else
:
default_target
=
DomainTuple
.
make
(
None
)
self
.
_target
=
default_target
elif
flatten
:
raise
ValueError
(
"Flattening is supported only for endomorphic application,"
+
" and you can't specify a target."
)
elif
target
.
shape
==
target_space_shape
:
self
.
_target
=
target
self
.
_target
=
DomainTuple
.
make
(
target
)
else
:
raise
ValueError
(
"Target space has invalid shape."
)
raise
ValueError
(
f
"Target space has invalid shape.
\n
"
+
"Its shape should be {target_space_shape}."
)
if
matrix
.
shape
[
mat_inactive_axes_dim
:]
!=
domain_shape
:
raise
ValueError
(
"Matrix doesn't fit with the domain."
)
matrix_appl_shape
=
matrix
.
shape
[
mat_inactive_axes_dim
:]
if
matrix_appl_shape
!=
domain_shape
:
raise
ValueError
(
"Matrix doesn't fit with the domain.
\n
"
+
f
"Shape of matrix axes used in summation:
{
matrix_appl_shape
}
,
\n
"
+
f
"Shape of domain axes used in summation:
{
domain_shape
}
."
)
self
.
_mat
=
matrix
self
.
_mat_tr
=
matrix
.
transpose
().
conjugate
()
...
...
@@ -158,29 +184,25 @@ class GeneralMatrixProduct(LinearOperator):
if
self
.
_spaces
is
None
:
if
not
self
.
_flatten
:
if
times
:
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
len
(
x
.
shape
))
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
len
(
x
.
domain
.
shape
))
else
:
mat_axes
=
np
.
flip
(
self
.
_mat_last_m
)
field_axes
=
list
(
range
(
len
(
self
.
_target
.
shape
)))
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
field_axes
))
res
=
res
.
reshape
(
np
.
flip
(
res
.
shape
)
)
field_axes
=
self
.
_target
_axes
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
field_axes
))
res
=
res
.
transpose
(
)
else
:
res
=
m
.
dot
(
x
.
val
.
flatten
()).
reshape
(
self
.
_domain
.
shape
)
return
Field
(
target
,
res
)
return
Field
(
target
,
res
)
if
times
:
mat_axes
=
self
.
_mat_last_n
move_axes
=
self
.
_target_last_n
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
self
.
_active_axes
))
res
=
np
.
moveaxis
(
res
,
move_axes
,
self
.
_inactive_axes
)
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
self
.
_active_axes
))
res
=
np
.
moveaxis
(
res
,
move_axes
,
self
.
_inactive_axes
)
else
:
mat_axes
=
np
.
flip
(
self
.
_mat_last_m
)
move_axes
=
np
.
flip
(
self
.
_mat_first_n
)
field_axes
=
list
(
range
(
len
(
self
.
_target
.
shape
)))
for
i
in
range
(
len
(
field_axes
)):
if
i
in
self
.
_inactive_axes
:
field_axes
.
remove
(
i
)
field_axes
=
tuple
(
field_axes
)
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
field_axes
))
res
=
np
.
moveaxis
(
res
,
move_axes
,
self
.
_active_axes
)
return
Field
(
target
,
res
)
field_axes
=
self
.
_field_axes
res
=
np
.
tensordot
(
m
,
x
.
val
,
axes
=
(
mat_axes
,
field_axes
))
res
=
np
.
moveaxis
(
res
,
move_axes
,
self
.
_active_axes
)
return
Field
(
target
,
res
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment