Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
559b8416
Commit
559b8416
authored
Aug 06, 2021
by
Philipp Arras
Browse files
Add JaxLinearOperator
parent
70a698c2
Pipeline
#107174
passed with stages
in 21 minutes and 4 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
559b8416
...
...
@@ -23,7 +23,7 @@ from .operator import Operator
try
:
import
jax
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
__all__
=
[
"JaxOperator"
,
"JaxLikelihoodEnergyOperator"
]
__all__
=
[
"JaxOperator"
,
"JaxLikelihoodEnergyOperator"
,
"JaxLinearOperator"
]
except
ImportError
:
__all__
=
[]
...
...
@@ -65,7 +65,7 @@ class JaxOperator(Operator):
if
is_linearization
(
x
):
res
,
bwd
=
self
.
_vjp
(
x
.
val
.
val
)
fwd
=
lambda
y
:
self
.
_fwd
(
x
.
val
.
val
,
y
)
jac
=
_Jax
Jacobian
(
self
.
_domain
,
self
.
_target
,
fwd
,
bwd
)
jac
=
_Jax
LinearOperator
(
self
.
_domain
,
self
.
_target
,
fwd
,
bwd
)
return
x
.
new
(
makeField
(
self
.
_target
,
_jax2np
(
res
)),
jac
)
res
=
_jax2np
(
self
.
_func
(
x
.
val
))
if
isinstance
(
res
,
dict
):
...
...
@@ -100,6 +100,43 @@ class JaxOperator(Operator):
return
None
,
JaxOperator
(
dom
,
self
.
_target
,
func2
)
def
JaxLinearOperator
(
domain
,
target
,
func
,
domain_dtype
):
"""Wrap a jax function as nifty linear operator.
Parameters
----------
domain : DomainTuple or MultiDomain
Domain of the operator.
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the
target.
domain_dtype:
Dtype of the domain. If `domain` is a `MultiDomain`, `domain_dtype` is
supposed to be a dictionary.
Note
----
It is the user's responsibility that func is actually a linear function. The
user can double check this with the help of `nifty8.extra.check_linear_operator`.
"""
from
..domain_tuple
import
DomainTuple
from
..sugar
import
makeDomain
domain
=
makeDomain
(
domain
)
if
isinstance
(
domain
,
DomainTuple
):
inp
=
np
.
ones
(
domain
.
shape
,
domain_dtype
)
else
:
inp
=
{
kk
:
np
.
ones
(
domain
[
kk
].
shape
,
domain_dtype
[
kk
])
for
kk
in
domain
.
keys
()}
func_transposed
=
jax
.
jit
(
jax
.
vjp
(
func
,
inp
)[
1
])
return
_JaxLinearOperator
(
domain
,
target
,
func
,
func_transposed
)
class
JaxLikelihoodEnergyOperator
(
LikelihoodEnergyOperator
):
"""Wrap a jax function as nifty likelihood energy operator.
...
...
@@ -163,7 +200,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
return
None
,
JaxLikelihoodEnergyOperator
(
dom
,
func2
,
trafo
,
dt
)
class
_Jax
Jacobian
(
LinearOperator
):
class
_Jax
LinearOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
,
func
,
func_transposed
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
...
...
test/test_operators/test_jax.py
View file @
559b8416
...
...
@@ -29,16 +29,26 @@ pmp = pytest.mark.parametrize
@
pmp
(
"dom"
,
[
ift
.
RGSpace
((
10
,
8
)),
(
ift
.
RGSpace
(
10
),
ift
.
RGSpace
(
8
))])
@
pmp
(
"func"
,
[
lambda
x
:
x
,
lambda
x
:
x
**
2
,
lambda
x
:
x
*
x
,
lambda
x
:
x
*
x
[
0
,
0
],
lambda
x
:
jnp
.
sin
(
x
),
lambda
x
:
x
*
x
.
sum
()])
@
pmp
(
"func"
,
[(
lambda
x
:
x
,
True
),
(
lambda
x
:
x
**
2
,
False
),
(
lambda
x
:
x
*
x
,
False
),
(
lambda
x
:
x
*
x
[
0
,
0
],
False
),
(
lambda
x
:
x
+
x
[
0
,
0
],
True
),
(
lambda
x
:
jnp
.
sin
(
x
),
False
),
(
lambda
x
:
x
*
x
.
sum
(),
False
),
(
lambda
x
:
x
+
x
.
sum
(),
True
)])
def
test_jax
(
dom
,
func
):
pytest
.
importorskip
(
"jax"
)
loc
=
ift
.
from_random
(
dom
)
res0
=
np
.
array
(
func
(
loc
.
val
))
op
=
ift
.
JaxOperator
(
dom
,
dom
,
func
)
f
,
linear
=
func
res0
=
np
.
array
(
f
(
loc
.
val
))
op
=
ift
.
JaxOperator
(
dom
,
dom
,
f
)
np
.
testing
.
assert_allclose
(
res0
,
op
(
loc
).
val
)
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
))
op
=
ift
.
JaxLinearOperator
(
dom
,
dom
,
f
,
np
.
float64
)
if
linear
:
ift
.
extra
.
check_linear_operator
(
op
)
else
:
with
pytest
.
raises
(
AssertionError
):
ift
.
extra
.
check_linear_operator
(
op
)
def
test_mf_jax
():
pytest
.
importorskip
(
"jax"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment