Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
f557057f
Commit
f557057f
authored
Jun 30, 2021
by
Philipp Arras
Browse files
Use jax.value_and_grad for energy operator
parent
272829d1
Pipeline
#104760
passed with stages
in 16 minutes and 18 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
f557057f
...
...
@@ -101,7 +101,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_func
=
jax
.
jit
(
func
)
self
.
_grad
=
jax
.
jit
(
jax
.
grad
(
func
))
self
.
_
val_and_
grad
=
jax
.
jit
(
jax
.
value_and_
grad
(
func
))
self
.
_dt
=
sampling_dtype
self
.
_trafo
=
transformation
...
...
@@ -118,11 +118,11 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
self
.
_check_input
(
x
)
lin
=
is_linearization
(
x
)
val
=
x
.
val
.
val
if
lin
else
x
.
val
res
=
makeField
(
self
.
_target
,
_jax2np
(
self
.
_func
(
val
)))
if
not
lin
:
return
res
jac
=
VdotOperator
(
makeField
(
self
.
_domain
,
_jax2np
(
self
.
_grad
(
val
))))
res
=
x
.
new
(
res
,
jac
)
return
makeField
(
self
.
_target
,
_jax2np
(
self
.
_func
(
val
)))
res
,
grad
=
self
.
_val_and_grad
(
val
)
jac
=
VdotOperator
(
makeField
(
self
.
_domain
,
_jax2np
(
grad
)))
res
=
x
.
new
(
makeField
(
self
.
_target
,
_jax2np
(
res
)),
jac
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment