Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
1aae596d
Commit
1aae596d
authored
Jun 07, 2021
by
Philipp Arras
Browse files
Add wrapper for jax
parent
94171dce
Changes
5
Hide whitespace changes
Inline
Side-by-side
Dockerfile
View file @
1aae596d
...
...
@@ -12,9 +12,7 @@ RUN apt-get update && apt-get install -y \
# Optional NIFTy dependencies
python3-mpi4py python3-matplotlib \
# more optional NIFTy dependencies
&& pip3 install ducc0 \
&& pip3 install finufft \
&& pip3 install jupyter \
&& pip3 install ducc0 finufft jupyter jax jaxlib \
&& rm -rf /var/lib/apt/lists/*
# Set matplotlib backend
...
...
README.md
View file @
1aae596d
...
...
@@ -53,6 +53,7 @@ Optional dependencies:
harmonic transforms, and radio interferometry gridding support
-
[
mpi4py
](
https://mpi4py.scipy.org
)
(
for
MPI-parallel execution)
-
[
matplotlib
](
https://matplotlib.org/
)
(
for
field plotting)
-
[
jax
](
https://github.com/google/jax
)
(
for
implementing operators with jax)
### Sources
...
...
@@ -79,6 +80,8 @@ The DUCC0 package is installed via:
pip3 install ducc0
For installing jax refer to
[
google/jax:README#Installation
](
https://github.com/google/jax#installation
)
.
If this library is present, NIFTy will detect it automatically and prefer
`ducc0.fft`
over SciPy's FFT. The underlying code is actually the same, but
DUCC's FFT is compiled with optimizations for the host CPU and can provide
...
...
src/__init__.py
View file @
1aae596d
...
...
@@ -54,6 +54,7 @@ from .operators.energy_operators import (
from
.operators.convolution_operators
import
FuncConvolutionOperator
from
.operators.normal_operators
import
NormalTransform
,
LognormalTransform
from
.operators.multifield2vector
import
Multifield2Vector
from
.operators.jax_operator
import
*
from
.probing
import
probe_with_posterior_samples
,
probe_diagonal
,
\
StatCalculator
,
approximation2endo
...
...
src/operators/jax_operator.py
0 → 100644
View file @
1aae596d
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import
numpy
as
np
from
.operator
import
Operator
from
.linear_operator
import
LinearOperator
try
:
import
jax
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
__all__
=
[
"JaxOperator"
]
except
ImportError
:
__all__
=
[]
def
_jax2np
(
obj
):
if
isinstance
(
obj
,
dict
):
return
{
kk
:
np
.
array
(
vv
)
for
kk
,
vv
in
obj
.
items
()}
return
np
.
array
(
obj
)
class
JaxOperator
(
Operator
):
"""Wrap a jax function as nifty operator.
Parameters
----------
domain : DomainTuple or MultiDomain
Domain of the operator.
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the
target.
"""
def
__init__
(
self
,
domain
,
target
,
func
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_func
=
jax
.
jit
(
func
)
def
apply
(
self
,
x
):
from
..sugar
import
is_linearization
,
makeField
self
.
_check_input
(
x
)
if
is_linearization
(
x
):
res
,
bwd
=
jax
.
vjp
(
self
.
_func
,
x
.
val
.
val
)
fwd
=
lambda
y
:
jax
.
jvp
(
self
.
_func
,
(
x
.
val
.
val
,),
(
y
,))[
1
]
jac
=
_JaxJacobian
(
self
.
_domain
,
self
.
_target
,
fwd
,
bwd
)
return
x
.
new
(
makeField
(
self
.
_target
,
_jax2np
(
res
)),
jac
)
return
makeField
(
self
.
_target
,
_jax2np
(
self
.
_func
(
x
.
val
)))
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
func2
=
lambda
x
:
self
.
_func
({
**
x
,
**
c_inp
.
val
})
dom
=
{
kk
:
vv
for
kk
,
vv
in
self
.
_domain
.
items
()
if
kk
not
in
c_inp
.
keys
()}
return
None
,
JaxOperator
(
dom
,
self
.
_target
,
func2
)
class
_JaxJacobian
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
,
func
,
adjfunc
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_func
=
func
self
.
_adjfunc
=
adjfunc
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
from
..sugar
import
makeField
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
fx
=
self
.
_func
(
x
.
val
)
else
:
fx
=
self
.
_adjfunc
(
x
.
val
)[
0
]
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
))
test/test_operators/test_jax.py
0 → 100644
View file @
1aae596d
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import
nifty7
as
ift
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
pytest
try
:
import
jax.numpy
as
jnp
_skip
=
False
except
ImportError
:
import
numpy
as
np
_skip
=
True
from
..common
import
setup_function
,
teardown_function
pmp
=
pytest
.
mark
.
parametrize
@
pmp
(
"dom"
,
[
ift
.
RGSpace
((
10
,
8
)),
(
ift
.
RGSpace
(
10
),
ift
.
RGSpace
(
8
))])
@
pmp
(
"func"
,
[
lambda
x
:
x
,
lambda
x
:
x
**
2
,
lambda
x
:
x
*
x
,
lambda
x
:
x
*
x
[
0
,
0
],
lambda
x
:
jnp
.
sin
(
x
),
lambda
x
:
x
*
x
.
sum
()])
def
test_jax
(
dom
,
func
):
if
_skip
:
pytest
.
skip
()
loc
=
ift
.
from_random
(
dom
)
res0
=
np
.
array
(
func
(
loc
.
val
))
op
=
ift
.
JaxOperator
(
dom
,
dom
,
func
)
np
.
testing
.
assert_allclose
(
res0
,
op
(
loc
).
val
)
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
))
def
test_mf_jax
():
if
_skip
:
pytest
.
skip
()
dom
=
ift
.
makeDomain
({
"a"
:
ift
.
RGSpace
(
10
),
"b"
:
ift
.
UnstructuredDomain
(
2
)})
func
=
lambda
x
:
x
[
"a"
]
*
x
[
"b"
][
0
]
op
=
ift
.
JaxOperator
(
dom
,
dom
[
"a"
],
func
)
loc
=
ift
.
from_random
(
op
.
domain
)
np
.
testing
.
assert_allclose
(
np
.
array
(
func
(
loc
.
val
)),
op
(
loc
).
val
)
ift
.
extra
.
check_operator
(
op
,
loc
)
func
=
lambda
x
:
{
"a"
:
jnp
.
full
(
dom
[
"a"
].
shape
,
2.
)
*
x
[
0
]
*
x
[
1
],
"b"
:
jnp
.
full
(
dom
[
"b"
].
shape
,
1.
)
*
jnp
.
exp
(
x
[
0
])}
op
=
ift
.
JaxOperator
(
dom
[
"b"
],
dom
,
func
)
loc
=
ift
.
from_random
(
op
.
domain
)
for
kk
in
dom
.
keys
():
np
.
testing
.
assert_allclose
(
np
.
array
(
func
(
loc
.
val
)[
kk
]),
op
(
loc
)[
kk
].
val
)
ift
.
extra
.
check_operator
(
op
,
loc
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment