Skip to content
Snippets Groups Projects

Draft: Clean experiments

Open Philipp Arras requested to merge clean_experiments into master
Files
9
+ 224
0
import resolve as rve
import ducc0
import numpy as np
import nifty8 as ift
import sys
npix = 1000
npix = 2 * ducc0.fft.good_size(npix // 2)
oversampling_factor = 4 # FIXME Why does this need to be so large?
minor_conv = {"rel first peak": 0.2, "max components": 3000, "gain": 0.8}
nmajor = 10
weighting_scheme = "uniform"
oversampling_highres = 2
do_wgridding = True
epsilon = 1e-8
nthreads = 6
if True:
do_wgridding = False
epsilon = 1e-4
nthreads = 6
def highres_domain_and_lowrew_mask(dom, oversampling, xmin, xmax, ymin, ymax):
rve.assert_sky_domain(dom)
_, _, _, sdom = dom
if isinstance(xmin, str):
xmin = rve.str2rad(xmin)
if isinstance(xmax, str):
xmax = rve.str2rad(xmax)
if isinstance(ymin, str):
ymin = rve.str2rad(ymin)
if isinstance(ymax, str):
ymax = rve.str2rad(ymax)
if not isinstance(oversampling, int):
s = f"`oversampling` needs to be int. Got: {oversampling}"
raise TypeError(s)
Dstx, Dsty = sdom.distances
Nx, Ny = sdom.shape
xmin = np.round(Nx//2 + xmin/Dstx).astype(int)
xmax = np.round(Nx//2 + xmax/Dstx).astype(int)
ymin = np.round(Ny//2 + ymin/Dsty).astype(int)
ymax = np.round(Ny//2 + ymax/Dsty).astype(int)
mask_lowres = np.ones(sdom.shape)
mask_lowres[xmin:xmax, ymin:ymax] = 0.
mask_lowres = ift.makeField(sdom, mask_lowres)
center_highres = (xmin + (xmax-xmin)//2 - Nx//2)*Dstx, (ymin + (ymax-ymin)//2 - Ny//2)*Dsty
print(np.array(center_highres) / rve.ARCMIN2RAD)
npix = np.array([(xmax-xmin)*oversampling, (ymax-ymin)*oversampling])
assert (npix % 2 == 0).all()
dst = np.array(sdom.distances) / oversampling
dom_highres = ift.makeDomain(dom[:3] + (ift.RGSpace(npix, dst),))
rve.assert_sky_domain(dom_highres)
np.testing.assert_allclose(sdom.total_volume - mask_lowres.s_integrate(),
dom_highres[-1].total_volume)
mask_lowres = ift.ContractionOperator(dom, (0, 1, 2)).adjoint(mask_lowres)
return dom_highres, mask_lowres, center_highres
def main():
_, ms = sys.argv
ms = list(rve.ms2observations_all(ms, "DATA"))
assert len(ms) == 1
ms = ms[0]
ms = ms.restrict_to_stokesi().to_double_precision()
dst = ms.nyquist_resolution() / oversampling_factor
sdom = ift.RGSpace([npix, npix], [dst, dst])
print(f"Field of view: {npix*dst/rve.DEG2RAD:.2} deg")
dom = rve.default_sky_domain(sdom=sdom)
assert all(dd == 1 for dd in dom.shape[:3])
squeeze = ift.ContractionOperator(dom, (0, 1, 2))
dom_highres, mask_lowres, center_highres = highres_domain_and_lowrew_mask(dom, oversampling_highres, "-0.94arcmin", "-0.78arcmin", "-0.40arcmin", "-0.28arcmin")
dom_highres, mask_lowres, center_highres = highres_domain_and_lowrew_mask(dom, oversampling_highres, "-0.2arcmin", "0.2arcmin", "-0.2arcmin", "0.2arcmin")
direction_highres = rve.Direction(center_highres, ms.direction.equinox) + ms.direction
R = rve.InterferometryResponse(ms, dom, do_wgridding, epsilon, nthreads=nthreads)
R_highres = rve.InterferometryResponse(ms, dom_highres, do_wgridding, epsilon,
center=center_highres, nthreads=nthreads)
dirty = get_dirty(ms, dom, ms.vis)
dirty_highres = get_dirty(ms, dom_highres, ms.vis, center=center_highres)
psf = get_psf(ms, dom)
psf_highres = get_psf(ms, dom, center=center_highres)
rve.ubik_tools.field2fits(dirty, "dirty.fits", ms.direction)
rve.ubik_tools.field2fits(dirty_highres, "dirty_highres.fits", direction_highres)
exit()
rve.ubik_tools.field2fits(psf, "psf.fits")
imajor = 0
residuals = ms.vis
model = ift.full(dom, 0.0)
while True: # Major cycle
print(f"Working on major iteration {imajor}")
model, err = major_cycle(ms, dom, residuals, model, squeeze(psf), minor_conv)
if err:
break
model_data = R(model)
residuals = ms.vis - model_data
rve.ubik_tools.field2fits(get_dirty(ms, dom, residuals), f"residual_iter{imajor}.fits")
rve.ubik_tools.field2fits(
get_dirty(ms, dom, model_data, weighting_scheme="uniform"),
f"restored_iter{imajor}.fits",
)
rve.ubik_tools.field2fits(model, f"model_iter{imajor}.fits")
# rve.plot.visualize_weighted_residuals(
# [ms],
# ift.SampleList([model]),
# imajor,
# ift.Operator.identity_operator(dom),
# None,
# output_directory=".",
# do_wgridding=do_wgridding,
# epsilon=epsilon,
# nthreads=nthreads,
# io=True,
# )
if imajor + 1 == nmajor:
break
imajor += 1
def find_peak(arr):
inds = np.unravel_index(np.argmax(np.abs(arr)), arr.shape)
peak = arr[inds]
return inds, peak
def major_cycle(ms, dom, residuals, model, psf, convergence_criteria):
_, _, _, sdom = dom
assert psf.domain == ift.makeDomain(sdom)
assert dom == model.domain
assert residuals.domain == ms.vis.domain
nminor = 10
dirty = get_dirty(ms, dom, residuals)
dirty = np.squeeze(dirty.val_rw())
assert dirty.shape == sdom.shape
nx, ny = dirty.shape
_, ref_flux = find_peak(dirty)
model = np.squeeze(model.val_rw())
iminor = 0
while True:
(indx, indy), flux = find_peak(dirty)
print(f"Flux / ref_flux: {flux / ref_flux * 100:.1f}%")
if iminor == convergence_criteria["max components"]:
break
if np.abs(flux) < np.abs(ref_flux) * convergence_criteria["rel first peak"]:
break
flux *= convergence_criteria["gain"]
# Update model
model[indx, indy] += flux
# Subtract component from dirty image
off = nx // 2 - indx
if off > 0:
psf_x = slice(off, None)
dirty_x = slice(None, -off)
elif off == 0:
psf_x = slice(None)
dirty_x = slice(None)
else:
psf_x = slice(None, off)
dirty_x = slice(-off, None)
off = ny // 2 - indy
if off > 0:
psf_y = slice(off, None)
dirty_y = slice(None, -off)
elif off == 0:
psf_y = slice(None)
dirty_y = slice(None)
else:
psf_y = slice(None, off)
dirty_y = slice(-off, None)
dirty[dirty_x, dirty_y] -= flux * psf.val[psf_x, psf_y]
iminor += 1
print("Major cycle finished")
print(f"Ref flux: {ref_flux}, {iminor} components")
return ift.makeField(dom, model.reshape(dom.shape)), 0
def get_psf(ms, dom, center=(0., 0.)):
psf = rve.dirty_image(
ms,
weighting_scheme,
dom,
do_wgridding,
epsilon,
nthreads=nthreads,
vis=ms.vis * 0 + 1,
center=(0., 0.)
)
return psf / psf.val.max()
def get_dirty(ms, dom, vis, weighting_scheme=weighting_scheme, center=(0., 0.)):
dirty = rve.dirty_image(
ms, weighting_scheme, dom, do_wgridding, epsilon, nthreads=nthreads, vis=vis, center=center
)
return dirty
if __name__ == "__main__":
main()
Loading