Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
9aaf8783
Commit
9aaf8783
authored
May 27, 2019
by
Martin Reinecke
Browse files
first try
parent
a439a47c
Pipeline
#49995
failed with stages
in 4 minutes and 59 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
demos/bench_gridder.py
View file @
9aaf8783
...
...
@@ -27,12 +27,9 @@ for ii in range(10, 23):
img
=
ift
.
from_global_data
(
uvspace
,
img
)
t0
=
time
()
GM
=
ift
.
GridderMaker
(
uvspace
,
eps
=
1e-7
)
idx
=
GM
.
getReordering
(
uv
)
uv
=
uv
[
idx
]
vis
=
vis
[
idx
]
GM
=
ift
.
GridderMaker
(
uvspace
,
eps
=
1e-7
,
uv
=
uv
)
vis
=
ift
.
from_global_data
(
visspace
,
vis
)
op
=
GM
.
getFull
(
uv
).
adjoint
op
=
GM
.
getFull
().
adjoint
t1
=
time
()
op
(
img
).
to_global_data
()
t2
=
time
()
...
...
nifty5/library/gridder.py
View file @
9aaf8783
...
...
@@ -26,52 +26,45 @@ from ..sugar import from_global_data, makeDomain
class
GridderMaker
(
object
):
def
__init__
(
self
,
domain
,
eps
=
2e-13
):
from
nifty_gridder
import
get_w
domain
=
makeDomain
(
domain
)
if
(
len
(
domain
)
!=
1
or
not
isinstance
(
domain
[
0
],
RGSpace
)
or
not
len
(
domain
.
shape
)
==
2
):
raise
ValueError
(
"need domain with exactly one 2D RGSpace"
)
nu
,
nv
=
domain
.
shape
if
nu
%
2
!=
0
or
nv
%
2
!=
0
:
raise
ValueError
(
"dimensions must be even"
)
nu2
,
nv2
=
2
*
nu
,
2
*
nv
w
=
get_w
(
eps
)
nsafe
=
(
w
+
1
)
//
2
nu2
=
max
([
nu2
,
2
*
nsafe
])
nv2
=
max
([
nv2
,
2
*
nsafe
])
oversampled_domain
=
RGSpace
(
[
nu2
,
nv2
],
distances
=
[
1
,
1
],
harmonic
=
False
)
self
.
_eps
=
eps
self
.
_rest
=
_RestOperator
(
domain
,
oversampled_domain
,
eps
)
def
getReordering
(
self
,
uv
):
from
nifty_gridder
import
peanoindex
nu2
,
nv2
=
self
.
_rest
.
_domain
.
shape
return
peanoindex
(
uv
,
nu2
,
nv2
)
def
getGridder
(
self
,
uv
):
return
RadioGridder
(
self
.
_rest
.
domain
,
self
.
_eps
,
uv
)
def
__init__
(
self
,
dirty_domain
,
uv
,
eps
=
2e-13
):
import
nifty_gridder
dirty_domain
=
makeDomain
(
dirty_domain
)
if
(
len
(
dirty_domain
)
!=
1
or
not
isinstance
(
dirty_domain
[
0
],
RGSpace
)
or
not
len
(
dirty_domain
.
shape
)
==
2
):
raise
ValueError
(
"need dirty_domain with exactly one 2D RGSpace"
)
bl
=
nifty_gridder
.
Baselines
(
uv
,
np
.
array
([
1.
]));
nxdirty
,
nydirty
=
dirty_domain
.
shape
gconf
=
nifty_gridder
.
GridderConfig
(
nxdirty
,
nydirty
,
eps
,
1.
,
1.
)
nu
=
gconf
.
Nu
()
nv
=
gconf
.
Nv
()
idx
=
bl
.
getIndices
()
idx
=
gconf
.
reorderIndices
(
idx
,
bl
)
grid_domain
=
RGSpace
([
nu
,
nv
],
distances
=
[
1
,
1
],
harmonic
=
False
)
self
.
_rest
=
_RestOperator
(
dirty_domain
,
grid_domain
,
gconf
)
self
.
_gridder
=
RadioGridder
(
grid_domain
,
bl
,
gconf
,
idx
)
def
getGridder
(
self
):
return
self
.
_gridder
def
getRest
(
self
):
return
self
.
_rest
def
getFull
(
self
,
uv
):
return
self
.
getRest
()
@
self
.
g
etG
ridder
(
uv
)
def
getFull
(
self
):
return
self
.
getRest
()
@
self
.
_
gridder
class
_RestOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
,
oversampled_domain
,
eps
):
from
nifty_gridder
import
correction_factors
self
.
_domain
=
makeDomain
(
oversampled_domain
)
self
.
_target
=
domain
nu
,
nv
=
domain
.
shape
nu2
,
nv2
=
oversampled_domain
.
shape
fu
=
correction_factors
(
nu2
,
nu
//
2
+
1
,
eps
)
fv
=
correction_factors
(
nv2
,
nv
//
2
+
1
,
eps
)
def
__init__
(
self
,
dirty_domain
,
grid_domain
,
gconf
):
import
nifty_gridder
self
.
_domain
=
makeDomain
(
grid_domain
)
self
.
_target
=
makeDomain
(
dirty_domain
)
self
.
_gconf
=
gconf
fu
=
gconf
.
U_corrections
()
fv
=
gconf
.
V_corrections
()
nu
,
nv
=
dirty_domain
.
shape
# compute deconvolution operator
rng
=
np
.
arange
(
nu
)
k
=
np
.
minimum
(
rng
,
nu
-
rng
)
...
...
@@ -82,6 +75,7 @@ class _RestOperator(LinearOperator):
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
import
nifty_gridder
self
.
_check_input
(
x
,
mode
)
nu
,
nv
=
self
.
_target
.
shape
res
=
x
.
to_global_data
()
...
...
@@ -103,20 +97,21 @@ class _RestOperator(LinearOperator):
class
RadioGridder
(
LinearOperator
):
def
__init__
(
self
,
target
,
eps
,
uv
):
self
.
_domain
=
DomainTuple
.
make
(
UnstructuredDomain
((
uv
.
shape
[
0
],)))
self
.
_target
=
DomainTuple
.
make
(
target
)
def
__init__
(
self
,
grid_domain
,
bl
,
gconf
,
idx
):
self
.
_domain
=
DomainTuple
.
make
(
UnstructuredDomain
((
idx
.
shape
[
0
],)))
self
.
_target
=
DomainTuple
.
make
(
grid_domain
)
self
.
_bl
=
bl
self
.
_gconf
=
gconf
self
.
_idx
=
idx
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_eps
=
float
(
eps
)
self
.
_uv
=
uv
# FIXME: should we write-protect this?
def
apply
(
self
,
x
,
mode
):
from
nifty_gridder
import
to_grid
,
from_grid
import
nifty_gridder
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
nu2
,
nv2
=
self
.
_target
.
shape
res
=
to_grid
(
self
.
_
uv
,
x
.
to_global_data
()
,
nu2
,
nv2
,
self
.
_eps
)
res
=
nifty_gridder
.
ms2grid
(
self
.
_bl
,
self
.
_gconf
,
self
.
_
idx
,
x
.
to_global_data
()
.
reshape
((
-
1
,
1
))
)
else
:
res
=
from_grid
(
self
.
_uv
,
x
.
to_global_data
(),
self
.
_eps
)
res
=
nifty_gridder
.
grid2ms
(
self
.
_bl
,
self
.
_gconf
,
self
.
_idx
,
x
.
to_global_data
())
return
from_global_data
(
self
.
_tgt
(
mode
),
res
)
test/test_operators/test_nft.py
View file @
9aaf8783
...
...
@@ -39,13 +39,10 @@ def test_gridding(nu, nv, N, eps):
vis
=
np
.
random
.
randn
(
N
)
+
1j
*
np
.
random
.
randn
(
N
)
# Nifty
GM
=
ift
.
GridderMaker
(
ift
.
RGSpace
((
nu
,
nv
)),
eps
=
eps
)
# re-order for performance
idx
=
GM
.
getReordering
(
uv
)
uv
,
vis
=
uv
[
idx
],
vis
[
idx
]
GM
=
ift
.
GridderMaker
(
ift
.
RGSpace
((
nu
,
nv
)),
eps
=
eps
,
uv
=
uv
)
vis2
=
ift
.
from_global_data
(
ift
.
UnstructuredDomain
(
vis
.
shape
),
vis
)
Op
=
GM
.
getFull
(
uv
)
Op
=
GM
.
getFull
()
pynu
=
Op
(
vis2
).
to_global_data
()
# DFT
x
,
y
=
np
.
meshgrid
(
...
...
@@ -63,14 +60,12 @@ def test_gridding(nu, nv, N, eps):
def
test_build
(
nu
,
nv
,
N
,
eps
):
dom
=
ift
.
RGSpace
([
nu
,
nv
])
uv
=
np
.
random
.
rand
(
N
,
2
)
-
0.5
GM
=
ift
.
GridderMaker
(
dom
,
eps
=
eps
)
GM
=
ift
.
GridderMaker
(
dom
,
eps
=
eps
,
uv
=
uv
)
# re-order for performance
idx
=
GM
.
getReordering
(
uv
)
uv
=
uv
[
idx
]
R0
=
GM
.
getGridder
(
uv
)
R0
=
GM
.
getGridder
()
R1
=
GM
.
getRest
()
R
=
R1
@
R0
RF
=
GM
.
getFull
(
uv
)
RF
=
GM
.
getFull
()
# Consistency checks
flt
=
np
.
float64
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment