Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
R
resolve
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Monitor
Service Desk
Analyze
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
ift
resolve
Commits
aaa56ce4
"nomad/atomutils.py" did not exist on "80f6942a8e8f6af8240377c4f7b5ac825de3a5d4"
Commit
aaa56ce4
authored
1 year ago
by
Jakob Roth
Browse files
Options
Downloads
Patches
Plain Diff
resolve.re: unify ducc and finufft response
parent
8c8811c7
No related branches found
No related tags found
No related merge requests found
Pipeline
#202537
passed
1 year ago
Stage: build_docker
Stage: testing
Stage: release
Stage: deploy
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
demo/imaging_resolve_jax.py
+14
-20
14 additions, 20 deletions
demo/imaging_resolve_jax.py
resolve/re/__init__.py
+1
-1
1 addition, 1 deletion
resolve/re/__init__.py
resolve/re/response.py
+53
-55
53 additions, 55 deletions
resolve/re/response.py
with
68 additions
and
76 deletions
demo/imaging_resolve_jax.py
+
14
−
20
View file @
aaa56ce4
...
...
@@ -11,8 +11,9 @@ from matplotlib.colors import LogNorm
import
configparser
from
jax
import
random
response
=
'
ducc
'
# response = "finu"
# choose between ducc0 and finufft backend
response
=
'
ducc0
'
# response = "finufft"
seed
=
42
key
=
random
.
PRNGKey
(
seed
)
...
...
@@ -21,7 +22,7 @@ jax.config.update("jax_enable_x64", True)
obs
=
rve
.
Observation
.
load
(
"
CYG-ALL-2052-2MHZ_RESOLVE_float64.npz
"
)
obs
=
obs
.
restrict_to_stokesi
()
#
obs = obs.average_stokesi()
obs
=
obs
.
average_stokesi
()
obs
.
_weight
=
0.1
*
obs
.
_weight
# scale weights, as they are wrong for this specific dataset
cfg
=
configparser
.
ConfigParser
()
cfg
.
read
(
"
cygnusa_2ghz.cfg
"
)
...
...
@@ -32,23 +33,16 @@ sky, additional = jrve.sky_model(cfg["sky"])
sky_sp
=
rve
.
sky_model
.
_spatial_dom
(
cfg
[
"
sky
"
])
sky_dom
=
rve
.
default_sky_domain
(
sdom
=
sky_sp
)
if
response
==
"
finu
"
:
R_finufft
=
jrve
.
InterferometryResponseFinuFFT
(
obs
,
sky_sp
.
distances
[
0
],
sky_sp
.
distances
[
1
],
1e-9
)
signal_response
=
lambda
x
:
R_finufft
(
sky
(
x
)[
0
,
0
,
0
,
:,
:])
elif
response
==
'
ducc
'
:
sky_domain_dict
=
dict
(
npix_x
=
sky_sp
.
shape
[
0
],
npix_y
=
sky_sp
.
shape
[
1
],
pixsize_x
=
sky_sp
.
distances
[
0
],
pixsize_y
=
sky_sp
.
distances
[
1
],
pol_labels
=
[
'
I
'
],
times
=
[
0.
],
freqs
=
[
0.
])
R_new
=
jrve
.
InterferometryResponse
(
obs
,
sky_domain_dict
,
False
,
1e-9
)
signal_response
=
lambda
x
:
R_new
(
sky
(
x
))
else
:
raise
ValueError
()
sky_domain_dict
=
dict
(
npix_x
=
sky_sp
.
shape
[
0
],
npix_y
=
sky_sp
.
shape
[
1
],
pixsize_x
=
sky_sp
.
distances
[
0
],
pixsize_y
=
sky_sp
.
distances
[
1
],
pol_labels
=
[
'
I
'
],
times
=
[
0.
],
freqs
=
[
0.
])
R_new
=
jrve
.
InterferometryResponse
(
obs
,
sky_domain_dict
,
False
,
1e-9
,
backend
=
response
)
signal_response
=
lambda
x
:
R_new
(
sky
(
x
))
nll
=
jft
.
Gaussian
(
obs
.
vis
.
val
,
obs
.
weight
.
val
).
amend
(
signal_response
)
...
...
This diff is collapsed.
Click to expand it.
resolve/re/__init__.py
+
1
−
1
View file @
aaa56ce4
from
.sky_model
import
sky_model_diffuse
,
sky_model_points
,
sky_model
from
.response
import
InterferometryResponse
,
InterferometryResponseFinuFFT
,
InterferometryResponseDucc
,
InterferometryResponseOld
\ No newline at end of file
from
.response
import
InterferometryResponse
,
InterferometryResponseFinuFFT
,
InterferometryResponseDucc
\ No newline at end of file
This diff is collapsed.
Click to expand it.
resolve/re/response.py
+
53
−
55
View file @
aaa56ce4
...
...
@@ -6,24 +6,25 @@ from functools import partial
from
..util
import
dtype_float2complex
from
jax.tree_util
import
Partial
def
get_binbounds
(
coordinates
):
if
len
(
coordinates
)
==
1
:
return
np
.
array
([
-
np
.
inf
,
np
.
inf
])
return
np
.
array
([
-
np
.
inf
,
np
.
inf
])
c
=
np
.
array
(
coordinates
)
bounds
=
np
.
empty
(
self
.
size
+
1
)
bounds
[
1
:
-
1
]
=
c
[:
-
1
]
+
0.5
*
np
.
diff
(
c
)
bounds
[
0
]
=
c
[
0
]
-
0.5
*
(
c
[
1
]
-
c
[
0
])
bounds
[
-
1
]
=
c
[
-
1
]
+
0.5
*
(
c
[
-
1
]
-
c
[
-
2
])
bounds
[
1
:
-
1
]
=
c
[:
-
1
]
+
0.5
*
np
.
diff
(
c
)
bounds
[
0
]
=
c
[
0
]
-
0.5
*
(
c
[
1
]
-
c
[
0
])
bounds
[
-
1
]
=
c
[
-
1
]
+
0.5
*
(
c
[
-
1
]
-
c
[
-
2
])
return
bounds
def
convert_polarization
(
inp
,
inp_pol
,
out_pol
):
if
inp_pol
==
(
'
I
'
,):
if
out_pol
==
(
'
LL
'
,
'
RR
'
)
or
out_pol
==
(
'
XX
'
,
'
YY
'
):
if
inp_pol
==
(
"
I
"
,):
if
out_pol
==
(
"
LL
"
,
"
RR
"
)
or
out_pol
==
(
"
XX
"
,
"
YY
"
):
new_shp
=
list
(
inp
.
shape
)
new_shp
[
0
]
=
2
return
jnp
.
broadcast_to
(
inp
,
new_shp
)
if
len
(
out_pol
)
==
1
and
out_pol
[
0
]
in
(
'
I
'
,
'
RR
'
,
'
LL
'
,
'
XX
'
,
'
yy
'
):
if
len
(
out_pol
)
==
1
and
out_pol
[
0
]
in
(
"
I
"
,
"
RR
"
,
"
LL
"
,
"
XX
"
,
"
yy
"
):
return
inp
err
=
f
"
conversion of polarization
{
inp_pol
}
to
{
out_pol
}
not implemented. Please implement!
"
raise
NotImplementedError
(
err
)
...
...
@@ -36,6 +37,7 @@ def InterferometryResponse(
epsilon
,
nthreads
=
1
,
verbosity
=
0
,
backend
=
"
ducc0
"
,
):
"""
Returns a function computing the radio interferometric response
...
...
@@ -45,6 +47,8 @@ def InterferometryResponse(
The observation for which the response should compute model visibilities
sky_domain_dict: dict
A dictionary providing information about the discretization of the sky.
do_wgridding : bool
Whether to perform wgridding.
epsilon: float
The numerical accuracy with which to evaluate the response.
nthreads: int, optional
...
...
@@ -52,39 +56,60 @@ def InterferometryResponse(
verbosity: int, optional
If set to 1 prints information about the setup and performance of the
response.
backend : string
If `ducc0` use ducc0 wgridder. If `finufft` use finufft to compute response.
"""
npix_x
=
sky_domain_dict
[
'
npix_x
'
]
npix_y
=
sky_domain_dict
[
'
npix_y
'
]
pixsize_x
=
sky_domain_dict
[
'
pixsize_x
'
]
pixsize_y
=
sky_domain_dict
[
'
pixsize_y
'
]
if
do_wgridding
and
backend
==
"
finufft
"
:
raise
RuntimeError
(
"
Cannot do wgridding with backend finufft.
"
)
npix_x
=
sky_domain_dict
[
"
npix_x
"
]
npix_y
=
sky_domain_dict
[
"
npix_y
"
]
pixsize_x
=
sky_domain_dict
[
"
pixsize_x
"
]
pixsize_y
=
sky_domain_dict
[
"
pixsize_y
"
]
n_pol
=
len
(
sky_domain_dict
[
'
pol_labels
'
])
n_pol
=
len
(
sky_domain_dict
[
"
pol_labels
"
])
# compute bins for time and freq
n_times
=
len
(
sky_domain_dict
[
'
times
'
])
bb_times
=
get_binbounds
(
sky_domain_dict
[
'
times
'
])
n_times
=
len
(
sky_domain_dict
[
"
times
"
])
bb_times
=
get_binbounds
(
sky_domain_dict
[
"
times
"
])
n_freqs
=
len
(
sky_domain_dict
[
'
freqs
'
])
bb_freqs
=
get_binbounds
(
sky_domain_dict
[
'
freqs
'
])
n_freqs
=
len
(
sky_domain_dict
[
"
freqs
"
])
bb_freqs
=
get_binbounds
(
sky_domain_dict
[
"
freqs
"
])
# build responses for: time binds, freq bins
sr
=
[]
row_indices
,
freq_indices
=
[],
[]
for
t
in
range
(
n_times
):
sr_tmp
,
t_tmp
,
f_tmp
=
[],
[],
[]
if
tuple
(
bb_times
[
t
:
t
+
2
])
==
(
-
np
.
inf
,
np
.
inf
):
if
tuple
(
bb_times
[
t
:
t
+
2
])
==
(
-
np
.
inf
,
np
.
inf
):
oo
=
observation
tind
=
slice
(
None
)
else
:
oo
,
tind
=
observation
.
restrict_by_time
(
bb_times
[
t
],
bb_times
[
t
+
1
],
True
)
oo
,
tind
=
observation
.
restrict_by_time
(
bb_times
[
t
],
bb_times
[
t
+
1
],
True
)
for
f
in
range
(
n_freqs
):
ooo
,
find
=
oo
.
restrict_by_freq
(
bb_freqs
[
f
],
bb_freqs
[
f
+
1
],
True
)
ooo
,
find
=
oo
.
restrict_by_freq
(
bb_freqs
[
f
],
bb_freqs
[
f
+
1
],
True
)
if
any
(
np
.
array
(
ooo
.
vis
.
shape
)
==
0
):
rrr
=
None
else
:
rrr
=
InterferometryResponseDucc
(
ooo
,
npix_x
,
npix_y
,
pixsize_x
,
pixsize_y
,
do_wgridding
,
epsilon
,
nthreads
,
verbosity
)
if
backend
==
"
ducc0
"
:
rrr
=
InterferometryResponseDucc
(
ooo
,
npix_x
,
npix_y
,
pixsize_x
,
pixsize_y
,
do_wgridding
,
epsilon
,
nthreads
,
verbosity
,
)
elif
backend
==
"
finufft
"
:
rrr
=
InterferometryResponseFinuFFT
(
ooo
,
pixsize_x
,
pixsize_y
,
epsilon
)
else
:
err
=
f
"
backend must be `ducc0` or `finufft` not
{
backend
}
"
raise
ValueError
(
err
)
sr_tmp
.
append
(
rrr
)
t_tmp
.
append
(
tind
)
...
...
@@ -93,18 +118,18 @@ def InterferometryResponse(
row_indices
.
append
(
t_tmp
)
freq_indices
.
append
(
f_tmp
)
target_shape
=
(
n_pol
,
)
+
tuple
(
observation
.
vis
.
shape
[
1
:])
target_shape
=
(
n_pol
,)
+
tuple
(
observation
.
vis
.
shape
[
1
:])
foo
=
np
.
zeros
(
target_shape
,
np
.
int8
)
for
pp
in
range
(
n_pol
):
for
tt
in
range
(
n_times
):
for
ff
in
range
(
n_freqs
):
foo
[
pp
,
row_indices
[
tt
][
ff
],
freq_indices
[
tt
][
ff
]]
=
1.
foo
[
pp
,
row_indices
[
tt
][
ff
],
freq_indices
[
tt
][
ff
]]
=
1.
0
if
np
.
any
(
foo
==
0
):
raise
RuntimeError
(
"
This should not happen. Please report.
"
)
inp_pol
=
tuple
(
sky_domain_dict
[
'
pol_labels
'
])
inp_pol
=
tuple
(
sky_domain_dict
[
"
pol_labels
"
])
out_pol
=
observation
.
vis
.
domain
[
0
].
labels
def
apply_R
(
sky
):
res
=
jnp
.
empty
(
target_shape
,
dtype_float2complex
(
sky
.
dtype
))
for
pp
in
range
(
sky
.
shape
[
0
]):
...
...
@@ -120,33 +145,6 @@ def InterferometryResponse(
return
apply_R
def
InterferometryResponseOld
(
observation
,
domain
,
do_wgridding
,
epsilon
,
verbosity
=
0
,
nthreads
=
1
):
import
jax_linop
from
..response
import
InterferometryResponse
R_old
=
InterferometryResponse
(
observation
,
domain
,
do_wgridding
,
epsilon
,
verbosity
,
nthreads
)
def
R
(
inp
,
out
,
state
):
inp
=
ift
.
makeField
(
R_old
.
domain
,
inp
)
out
[()]
=
R_old
(
inp
).
val
def
Re_T
(
inp
,
out
,
state
):
inp
=
ift
.
makeField
(
R_old
.
target
,
inp
.
conj
())
out
[()]
=
R_old
.
adjoint
(
inp
).
val
.
conj
()
def
R_abstract
(
shape
,
dtype
,
state
):
return
R_old
.
target
.
shape
,
np
.
dtype
(
np
.
complex128
)
def
R_abstract_T
(
shape
,
dtype
,
state
):
return
R_old
.
domain
.
shape
,
np
.
dtype
(
np
.
float64
)
R_jax
=
jax_linop
.
get_linear_call
(
R
,
Re_T
,
R_abstract
,
R_abstract_T
)
return
lambda
x
:
R_jax
(
x
)[
0
]
def
InterferometryResponseDucc
(
observation
,
...
...
@@ -195,7 +193,7 @@ def InterferometryResponseFinuFFT(observation, pixsizex, pixsizey, epsilon):
def
apply_finufft
(
inp
,
u
,
v
,
eps
):
res
=
vol
*
nufft2
(
inp
.
astype
(
np
.
complex128
),
u
,
v
,
eps
=
eps
)
return
jnp
.
expand_dims
(
res
.
reshape
(
-
1
,
len
(
freq
))
,
0
)
return
res
.
reshape
(
-
1
,
len
(
freq
))
R
=
p
artial
(
apply_finufft
,
u
=
u_finu
,
v
=
v_finu
,
eps
=
epsilon
)
R
=
P
artial
(
apply_finufft
,
u
=
u_finu
,
v
=
v_finu
,
eps
=
epsilon
)
return
R
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment