Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
67e0951c
Commit
67e0951c
authored
Aug 06, 2021
by
Philipp Arras
Browse files
Use linear_transpose for adjoint of JaxLinearOperator
parent
00000714
Changes
2
Hide whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
67e0951c
...
...
@@ -130,10 +130,10 @@ def JaxLinearOperator(domain, target, func, domain_dtype):
from
..sugar
import
makeDomain
domain
=
makeDomain
(
domain
)
if
isinstance
(
domain
,
DomainTuple
):
inp
=
np
.
ones
(
domain
.
shape
,
domain_dtype
)
inp
=
np
.
empty
(
domain
.
shape
,
domain_dtype
)
else
:
inp
=
{
kk
:
np
.
ones
(
domain
[
kk
].
shape
,
domain_dtype
[
kk
])
for
kk
in
domain
.
keys
()}
func_transposed
=
jax
.
jit
(
jax
.
vjp
(
func
,
inp
)
[
1
]
)
inp
=
{
kk
:
np
.
empty
(
domain
[
kk
].
shape
,
domain_dtype
[
kk
])
for
kk
in
domain
.
keys
()}
func_transposed
=
jax
.
jit
(
jax
.
linear_transpose
(
func
,
inp
))
return
_JaxLinearOperator
(
domain
,
target
,
func
,
func_transposed
)
...
...
test/test_operators/test_jax.py
View file @
67e0951c
...
...
@@ -46,7 +46,7 @@ def test_jax(dom, func):
if
linear
:
ift
.
extra
.
check_linear_operator
(
op
)
else
:
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
Exception
):
ift
.
extra
.
check_linear_operator
(
op
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment