Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
44ee2914
Commit
44ee2914
authored
Aug 11, 2021
by
Martin Reinecke
Browse files
Merge branch 'jax_linear_operator' into 'NIFTy_8'
Jax linear operator See merge request
!673
parents
80fd5456
1b646049
Pipeline
#107709
passed with stages
in 35 minutes and 51 seconds
Changes
2
Pipelines
6
Hide whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
44ee2914
...
@@ -14,17 +14,18 @@
...
@@ -14,17 +14,18 @@
# Copyright(C) 2021 Max-Planck-Society
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
# Author: Philipp Arras
from
types
import
SimpleNamespace
import
numpy
as
np
import
numpy
as
np
from
.operator
import
Operator
from
.energy_operators
import
EnergyOperator
,
LikelihoodEnergyOperator
from
.linear_operator
import
LinearOperator
from
.endomorphic_operator
import
EndomorphicOperator
from
.energy_operators
import
LikelihoodEnergyOperator
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
try
:
try
:
import
jax
import
jax
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
__all__
=
[
"JaxOperator"
,
"JaxLikelihoodEnergyOperator"
]
__all__
=
[
"JaxOperator"
,
"JaxLikelihoodEnergyOperator"
,
"JaxLinearOperator"
]
except
ImportError
:
except
ImportError
:
__all__
=
[]
__all__
=
[]
...
@@ -48,7 +49,7 @@ class JaxOperator(Operator):
...
@@ -48,7 +49,7 @@ class JaxOperator(Operator):
func : callable
func : callable
The jax function that is evaluated by the operator. It has to be
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
implemented in terms of `jax.numpy` calls. If `domain` is a
`Domain
Tuple
`, `func` takes a `dict` as argument and like-wise for the
`
Multi
Domain`, `func` takes a `dict` as argument and like-wise for the
target.
target.
"""
"""
def
__init__
(
self
,
domain
,
target
,
func
):
def
__init__
(
self
,
domain
,
target
,
func
):
...
@@ -60,13 +61,13 @@ class JaxOperator(Operator):
...
@@ -60,13 +61,13 @@ class JaxOperator(Operator):
self
.
_fwd
=
jax
.
jit
(
lambda
x
,
y
:
jax
.
jvp
(
self
.
_func
,
(
x
,),
(
y
,))[
1
])
self
.
_fwd
=
jax
.
jit
(
lambda
x
,
y
:
jax
.
jvp
(
self
.
_func
,
(
x
,),
(
y
,))[
1
])
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
from
..sugar
import
is_linearization
,
makeField
from
..multi_domain
import
MultiDomain
from
..multi_domain
import
MultiDomain
from
..sugar
import
is_linearization
,
makeField
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
if
is_linearization
(
x
):
if
is_linearization
(
x
):
res
,
bwd
=
self
.
_vjp
(
x
.
val
.
val
)
res
,
bwd
=
self
.
_vjp
(
x
.
val
.
val
)
fwd
=
lambda
y
:
self
.
_fwd
(
x
.
val
.
val
,
y
)
fwd
=
lambda
y
:
self
.
_fwd
(
x
.
val
.
val
,
y
)
jac
=
_
Jax
Jacobian
(
self
.
_domain
,
self
.
_target
,
fwd
,
bwd
)
jac
=
Jax
LinearOperator
(
self
.
_domain
,
self
.
_target
,
fwd
,
func_T
=
bwd
)
return
x
.
new
(
makeField
(
self
.
_target
,
_jax2np
(
res
)),
jac
)
return
x
.
new
(
makeField
(
self
.
_target
,
_jax2np
(
res
)),
jac
)
res
=
_jax2np
(
self
.
_func
(
x
.
val
))
res
=
_jax2np
(
self
.
_func
(
x
.
val
))
if
isinstance
(
res
,
dict
):
if
isinstance
(
res
,
dict
):
...
@@ -101,6 +102,70 @@ class JaxOperator(Operator):
...
@@ -101,6 +102,70 @@ class JaxOperator(Operator):
return
None
,
JaxOperator
(
dom
,
self
.
_target
,
func2
)
return
None
,
JaxOperator
(
dom
,
self
.
_target
,
func2
)
class
JaxLinearOperator
(
LinearOperator
):
"""Wrap a jax function as nifty linear operator.
Parameters
----------
domain : DomainTuple or MultiDomain
Domain of the operator.
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target.
func_T : callable
The jax function that implements the transposed action of the operator.
If None, jax computes the adjoint. Note that this is *not* the adjoint
action. Default: None.
domain_dtype:
Needs to be set if `func_transposed` is None. Otherwise it does not have
an effect. Dtype of the domain. If `domain` is a `MultiDomain`,
`domain_dtype` is supposed to be a dictionary. Default: None.
Note
----
It is the user's responsibility that func is actually a linear function. The
user can double check this with the help of
`nifty8.extra.check_linear_operator`.
"""
def
__init__
(
self
,
domain
,
target
,
func
,
domain_dtype
=
None
,
func_T
=
None
):
from
..domain_tuple
import
DomainTuple
from
..sugar
import
makeDomain
domain
=
makeDomain
(
domain
)
if
domain_dtype
is
not
None
and
func_T
is
None
:
if
isinstance
(
domain
,
DomainTuple
):
inp
=
SimpleNamespace
(
shape
=
domain
.
shape
,
dtype
=
domain_dtype
)
else
:
inp
=
{
kk
:
SimpleNameSpace
(
shape
=
domain
[
kk
].
shape
,
dtype
=
domain_dtype
[
kk
])
for
kk
in
domain
.
keys
()}
func_T
=
jax
.
jit
(
jax
.
linear_transpose
(
func
,
inp
))
elif
domain_dtype
is
None
and
func_T
is
not
None
:
pass
else
:
raise
ValueError
(
"Either domain_dtype or func_T have to be not None."
)
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_func
=
func
self
.
_func_T
=
func_T
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
from
..sugar
import
makeField
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
fx
=
self
.
_func
(
x
.
val
)
return
makeField
(
self
.
_target
,
_jax2np
(
fx
))
fx
=
self
.
_func_T
(
x
.
conjugate
().
val
)[
0
]
return
makeField
(
self
.
_domain
,
_jax2np
(
fx
)).
conjugate
()
class
JaxLikelihoodEnergyOperator
(
LikelihoodEnergyOperator
):
class
JaxLikelihoodEnergyOperator
(
LikelihoodEnergyOperator
):
"""Wrap a jax function as nifty likelihood energy operator.
"""Wrap a jax function as nifty likelihood energy operator.
...
@@ -112,7 +177,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
...
@@ -112,7 +177,7 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
func : callable
func : callable
The jax function that is evaluated by the operator. It has to be
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
implemented in terms of `jax.numpy` calls. If `domain` is a
`Domain
Tuple
`, `func` takes a `dict` as argument and like-wise for the
`
Multi
Domain`, `func` takes a `dict` as argument and like-wise for the
target. It needs to map to a scalar.
target. It needs to map to a scalar.
transformation : Operator, optional
transformation : Operator, optional
...
@@ -137,9 +202,9 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
...
@@ -137,9 +202,9 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
return
self
.
_dt
,
self
.
_trafo
return
self
.
_dt
,
self
.
_trafo
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
from
..linearization
import
Linearization
from
..sugar
import
is_linearization
,
makeField
from
..sugar
import
is_linearization
,
makeField
from
.simple_linear_operators
import
VdotOperator
from
.simple_linear_operators
import
VdotOperator
from
..linearization
import
Linearization
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
lin
=
is_linearization
(
x
)
lin
=
is_linearization
(
x
)
val
=
x
.
val
.
val
if
lin
else
x
.
val
val
=
x
.
val
.
val
if
lin
else
x
.
val
...
@@ -162,22 +227,3 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
...
@@ -162,22 +227,3 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
else
:
else
:
dt
=
self
.
_dt
dt
=
self
.
_dt
return
None
,
JaxLikelihoodEnergyOperator
(
dom
,
func2
,
trafo
,
dt
)
return
None
,
JaxLikelihoodEnergyOperator
(
dom
,
func2
,
trafo
,
dt
)
class
_JaxJacobian
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
,
func
,
func_transposed
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_func
=
func
self
.
_func_transposed
=
func_transposed
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
from
..sugar
import
makeField
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
fx
=
self
.
_func
(
x
.
val
)
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
))
fx
=
self
.
_func_transposed
(
x
.
conjugate
().
val
)[
0
]
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
)).
conjugate
()
test/test_operators/test_jax.py
View file @
44ee2914
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
import
nifty8
as
ift
import
nifty8
as
ift
import
numpy
as
np
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
pytest
import
pytest
try
:
try
:
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
...
@@ -29,16 +28,26 @@ pmp = pytest.mark.parametrize
...
@@ -29,16 +28,26 @@ pmp = pytest.mark.parametrize
@
pmp
(
"dom"
,
[
ift
.
RGSpace
((
10
,
8
)),
(
ift
.
RGSpace
(
10
),
ift
.
RGSpace
(
8
))])
@
pmp
(
"dom"
,
[
ift
.
RGSpace
((
10
,
8
)),
(
ift
.
RGSpace
(
10
),
ift
.
RGSpace
(
8
))])
@
pmp
(
"func"
,
[
lambda
x
:
x
,
lambda
x
:
x
**
2
,
lambda
x
:
x
*
x
,
lambda
x
:
x
*
x
[
0
,
0
],
@
pmp
(
"func"
,
[(
lambda
x
:
x
,
True
),
(
lambda
x
:
x
**
2
,
False
),
(
lambda
x
:
x
*
x
,
False
),
lambda
x
:
jnp
.
sin
(
x
),
lambda
x
:
x
*
x
.
sum
()])
(
lambda
x
:
x
*
x
[
0
,
0
],
False
),
(
lambda
x
:
x
+
x
[
0
,
0
],
True
),
(
lambda
x
:
jnp
.
sin
(
x
),
False
),
(
lambda
x
:
x
*
x
.
sum
(),
False
),
(
lambda
x
:
x
+
x
.
sum
(),
True
)])
def
test_jax
(
dom
,
func
):
def
test_jax
(
dom
,
func
):
pytest
.
importorskip
(
"jax"
)
pytest
.
importorskip
(
"jax"
)
loc
=
ift
.
from_random
(
dom
)
loc
=
ift
.
from_random
(
dom
)
res0
=
np
.
array
(
func
(
loc
.
val
))
f
,
linear
=
func
op
=
ift
.
JaxOperator
(
dom
,
dom
,
func
)
res0
=
np
.
array
(
f
(
loc
.
val
))
op
=
ift
.
JaxOperator
(
dom
,
dom
,
f
)
np
.
testing
.
assert_allclose
(
res0
,
op
(
loc
).
val
)
np
.
testing
.
assert_allclose
(
res0
,
op
(
loc
).
val
)
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
))
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
))
op
=
ift
.
JaxLinearOperator
(
dom
,
dom
,
f
,
np
.
float64
)
if
linear
:
ift
.
extra
.
check_linear_operator
(
op
)
else
:
with
pytest
.
raises
(
Exception
):
ift
.
extra
.
check_linear_operator
(
op
)
def
test_mf_jax
():
def
test_mf_jax
():
pytest
.
importorskip
(
"jax"
)
pytest
.
importorskip
(
"jax"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment