Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
R
resolve
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Monitor
Service Desk
Analyze
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
ift
resolve
Commits
c659f803
Commit
c659f803
authored
Apr 17, 2024
by
Jakob Roth
Browse files
Options
Downloads
Patches
Plain Diff
use JAXbind for binding ducc wgridder to JAX
parent
9f6a66e2
No related branches found
No related tags found
No related merge requests found
Pipeline
#202172
passed
Apr 17, 2024
Stage: build_docker
Stage: testing
Stage: release
Stage: deploy
Changes
3
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
README.md
+1
-1
1 addition, 1 deletion
README.md
demo/imaging_resolve_jax.py
+3
-9
3 additions, 9 deletions
demo/imaging_resolve_jax.py
resolve/re/response.py
+13
-27
13 additions, 27 deletions
resolve/re/response.py
with
17 additions
and
37 deletions
README.md
+
1
−
1
View file @
c659f803
...
@@ -38,7 +38,7 @@ Optional dependencies:
...
@@ -38,7 +38,7 @@ Optional dependencies:
-
matplotlib
-
matplotlib
-
dask-ms[xarray, zarr] (for reading pfb-clean xds files)
-
dask-ms[xarray, zarr] (for reading pfb-clean xds files)
-
[
jax-finufft
](
https://github.com/flatironinstitute/jax-finufft
)
(
for
using the finufft in jax-resolve)
-
[
jax-finufft
](
https://github.com/flatironinstitute/jax-finufft
)
(
for
using the finufft in jax-resolve)
-
[
jaxlinop
](
https://git
lab.mpcdf.mpg.de/mtr/jax_linop
)
(
for
using ducc gridder in jax-resolve)
-
[
JAXbind
](
https://git
hub.com/NIFTy-PPL/JAXbind
)
(
for
using ducc gridder in jax-resolve)
## Installation
## Installation
...
...
...
...
This diff is collapsed.
Click to expand it.
demo/imaging_resolve_jax.py
+
3
−
9
View file @
c659f803
...
@@ -11,8 +11,7 @@ from matplotlib.colors import LogNorm
...
@@ -11,8 +11,7 @@ from matplotlib.colors import LogNorm
import
configparser
import
configparser
from
jax
import
random
from
jax
import
random
response
=
'
old
'
response
=
'
ducc
'
response
=
'
new
'
# response = "finu"
# response = "finu"
seed
=
42
seed
=
42
...
@@ -33,17 +32,12 @@ sky, additional = jrve.sky_model(cfg["sky"])
...
@@ -33,17 +32,12 @@ sky, additional = jrve.sky_model(cfg["sky"])
sky_sp
=
rve
.
sky_model
.
_spatial_dom
(
cfg
[
"
sky
"
])
sky_sp
=
rve
.
sky_model
.
_spatial_dom
(
cfg
[
"
sky
"
])
sky_dom
=
rve
.
default_sky_domain
(
sdom
=
sky_sp
)
sky_dom
=
rve
.
default_sky_domain
(
sdom
=
sky_sp
)
if
response
==
"
old
"
:
if
response
==
"
finu
"
:
R_rve
=
jrve
.
InterferometryResponse
(
obs
,
sky_dom
,
False
,
1e-9
,
verbosity
=
0
,
nthreads
=
8
)
signal_response
=
lambda
x
:
R_rve
(
sky
(
x
))
elif
response
==
"
finu
"
:
R_finufft
=
jrve
.
InterferometryResponseFinuFFT
(
R_finufft
=
jrve
.
InterferometryResponseFinuFFT
(
obs
,
sky_sp
.
distances
[
0
],
sky_sp
.
distances
[
1
],
1e-9
obs
,
sky_sp
.
distances
[
0
],
sky_sp
.
distances
[
1
],
1e-9
)
)
signal_response
=
lambda
x
:
R_finufft
(
sky
(
x
)[
0
,
0
,
0
,
:,
:])
signal_response
=
lambda
x
:
R_finufft
(
sky
(
x
)[
0
,
0
,
0
,
:,
:])
elif
response
==
'
new
'
:
elif
response
==
'
ducc
'
:
sky_domain_dict
=
dict
(
npix_x
=
sky_sp
.
shape
[
0
],
sky_domain_dict
=
dict
(
npix_x
=
sky_sp
.
shape
[
0
],
npix_y
=
sky_sp
.
shape
[
1
],
npix_y
=
sky_sp
.
shape
[
1
],
pixsize_x
=
sky_sp
.
distances
[
0
],
pixsize_x
=
sky_sp
.
distances
[
0
],
...
...
...
...
This diff is collapsed.
Click to expand it.
resolve/re/response.py
+
13
−
27
View file @
c659f803
...
@@ -4,6 +4,7 @@ import nifty8 as ift
...
@@ -4,6 +4,7 @@ import nifty8 as ift
from
functools
import
partial
from
functools
import
partial
from
..util
import
dtype_float2complex
from
..util
import
dtype_float2complex
from
jax.tree_util
import
Partial
def
get_binbounds
(
coordinates
):
def
get_binbounds
(
coordinates
):
if
len
(
coordinates
)
==
1
:
if
len
(
coordinates
)
==
1
:
...
@@ -158,37 +159,22 @@ def InterferometryResponseDucc(
...
@@ -158,37 +159,22 @@ def InterferometryResponseDucc(
nthreads
=
1
,
nthreads
=
1
,
verbosity
=
0
,
verbosity
=
0
,
):
):
from
ducc0.wgridder.experimental
import
dirty2vis
,
vis2dirty
from
jaxbind.contrib
import
jaxducc0
import
jax_linop
vol
=
pixsize_x
*
pixsize_y
vol
=
pixsize_x
*
pixsize_y
nvis
=
observation
.
vis
.
shape
[
1
]
_args
=
{
"
uvw
"
:
observation
.
uvw
,
"
freq
"
:
observation
.
freq
,
"
pixsize_x
"
:
pixsize_x
,
"
pixsize_y
"
:
pixsize_y
,
"
epsilon
"
:
epsilon
,
"
do_wgridding
"
:
do_wgridding
,
"
nthreads
"
:
nthreads
,
"
flip_v
"
:
True
,
"
verbosity
"
:
verbosity
,
}
def
R
(
inp
,
out
,
state
):
out
[()]
=
dirty2vis
(
dirty
=
inp
,
**
_args
)
def
Re_T
(
inp
,
out
,
state
):
out
[()]
=
vis2dirty
(
vis
=
inp
.
conj
(),
npix_x
=
npix_x
,
npix_y
=
npix_y
,
**
_args
)
def
R_abstract
(
shape
,
dtype
,
state
):
wg
=
jaxducc0
.
get_wgridder
(
return
(
nvis
,
1
),
np
.
dtype
(
np
.
complex128
)
pixsize_x
=
pixsize_x
,
pixsize_y
=
pixsize_y
,
def
R_abstract_T
(
shape
,
dtype
,
state
):
npix_x
=
npix_x
,
return
(
npix_x
,
npix_y
),
np
.
dtype
(
np
.
float64
)
npix_y
=
npix_y
,
epsilon
=
epsilon
,
do_wgridding
=
do_wgridding
,
nthreads
=
nthreads
,
)
wgridder
=
Partial
(
wg
,
observation
.
uvw
,
observation
.
freq
)
R_jax
=
jax_linop
.
get_linear_call
(
R
,
Re_T
,
R_abstract
,
R_abstract_T
)
return
lambda
x
:
vol
*
wgridder
(
x
)[
0
]
return
lambda
x
:
vol
*
R_jax
(
x
)[
0
]
def
InterferometryResponseFinuFFT
(
observation
,
pixsizex
,
pixsizey
,
epsilon
):
def
InterferometryResponseFinuFFT
(
observation
,
pixsizex
,
pixsizey
,
epsilon
):
...
...
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
sign in
to comment