Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
404027de
Commit
404027de
authored
Jul 12, 2021
by
Philipp Arras
Browse files
JaxOperator: Add checks for output for debugging
parent
2db6c449
Changes
2
Show whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
404027de
...
...
@@ -61,13 +61,38 @@ class JaxOperator(Operator):
def
apply
(
self
,
x
):
from
..sugar
import
is_linearization
,
makeField
from
..multi_domain
import
MultiDomain
self
.
_check_input
(
x
)
if
is_linearization
(
x
):
res
,
bwd
=
self
.
_vjp
(
x
.
val
.
val
)
fwd
=
lambda
y
:
self
.
_fwd
(
x
.
val
.
val
,
y
)
jac
=
_JaxJacobian
(
self
.
_domain
,
self
.
_target
,
fwd
,
bwd
)
return
x
.
new
(
makeField
(
self
.
_target
,
_jax2np
(
res
)),
jac
)
return
makeField
(
self
.
_target
,
_jax2np
(
self
.
_func
(
x
.
val
)))
res
=
_jax2np
(
self
.
_func
(
x
.
val
))
if
isinstance
(
res
,
dict
):
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
raise
TypeError
((
"Jax function return a dictionary although the "
"target of the operator is a DomainTuple."
))
if
set
(
res
.
keys
())
!=
set
(
self
.
_target
.
keys
()):
raise
ValueError
((
"Keys do not match:
\n
"
f
"Target keys:
{
self
.
_target
.
keys
()
}
\n
"
f
"Jax function returns:
{
res
.
keys
()
}
"
))
for
kk
in
res
.
keys
():
self
.
_check_shape
(
self
.
_target
[
kk
].
shape
,
res
[
kk
].
shape
)
else
:
if
isinstance
(
self
.
_target
,
MultiDomain
):
raise
TypeError
((
"Jax function does not return a dictionary "
"although the target of the operator is a "
"MultiDomain."
))
self
.
_check_shape
(
self
.
_target
.
shape
,
res
.
shape
)
return
makeField
(
self
.
_target
,
res
)
@
staticmethod
def
_check_shape
(
shp_tgt
,
shp_jax
):
if
shp_tgt
!=
shp_jax
:
raise
ValueError
((
"Output shapes do not match:
\n
"
f
"Target shape is
\t\t
{
shp_tgt
}
\n
"
f
"Jax function returns
\t
{
shp_jax
}
"
))
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
func2
=
lambda
x
:
self
.
_func
({
**
x
,
**
c_inp
.
val
})
...
...
test/test_operators/test_jax.py
View file @
404027de
...
...
@@ -86,3 +86,21 @@ def test_jax_energy(dom):
continue
pos1
=
ift
.
from_random
(
e
.
domain
)
ift
.
extra
.
assert_allclose
(
e0
(
lin
).
metric
(
pos1
),
e
(
lin
).
metric
(
pos1
))
def
test_jax_errors
():
dom
=
ift
.
UnstructuredDomain
(
2
)
mdom
=
{
"a"
:
dom
}
op
=
ift
.
JaxOperator
(
dom
,
dom
,
lambda
x
:
{
"a"
:
x
})
fld
=
ift
.
full
(
dom
,
0.
)
with
pytest
.
raises
(
TypeError
):
op
(
fld
)
op
=
ift
.
JaxOperator
(
dom
,
mdom
,
lambda
x
:
x
)
with
pytest
.
raises
(
TypeError
):
op
(
fld
)
op
=
ift
.
JaxOperator
(
dom
,
dom
,
lambda
x
:
x
[
0
])
with
pytest
.
raises
(
ValueError
):
op
(
fld
)
op
=
ift
.
JaxOperator
(
dom
,
mdom
,
lambda
x
:
{
"a"
:
x
[
0
]})
with
pytest
.
raises
(
ValueError
):
op
(
fld
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment