Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
70a698c2
Commit
70a698c2
authored
Aug 06, 2021
by
Philipp Arras
Browse files
Cosmetics
parent
80fd5456
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
70a698c2
...
...
@@ -15,11 +15,10 @@
# Author: Philipp Arras
import
numpy
as
np
from
.operator
import
Operator
from
.energy_operators
import
EnergyOperator
,
LikelihoodEnergyOperator
from
.linear_operator
import
LinearOperator
from
.endomorphic_operator
import
EndomorphicOperator
from
.energy_operators
import
LikelihoodEnergyOperator
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
try
:
import
jax
...
...
@@ -60,8 +59,8 @@ class JaxOperator(Operator):
self
.
_fwd
=
jax
.
jit
(
lambda
x
,
y
:
jax
.
jvp
(
self
.
_func
,
(
x
,),
(
y
,))[
1
])
def
apply
(
self
,
x
):
from
..sugar
import
is_linearization
,
makeField
from
..multi_domain
import
MultiDomain
from
..sugar
import
is_linearization
,
makeField
self
.
_check_input
(
x
)
if
is_linearization
(
x
):
res
,
bwd
=
self
.
_vjp
(
x
.
val
.
val
)
...
...
@@ -137,9 +136,9 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
return
self
.
_dt
,
self
.
_trafo
def
apply
(
self
,
x
):
from
..linearization
import
Linearization
from
..sugar
import
is_linearization
,
makeField
from
.simple_linear_operators
import
VdotOperator
from
..linearization
import
Linearization
self
.
_check_input
(
x
)
lin
=
is_linearization
(
x
)
val
=
x
.
val
.
val
if
lin
else
x
.
val
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment