Skip to content
Snippets Groups Projects
Commit 6940587f authored by Philipp Arras's avatar Philipp Arras
Browse files

Plotting and export improvements

parent 4dccc027
Branches
Tags
1 merge request!39Draft: Clean experiments
import resolve as rve
import ducc0
import numpy as np
import nifty8 as ift
import sys
npix = 1000
npix = 2 * ducc0.fft.good_size(npix // 2)
oversampling_factor = 4
minor_conv = {"rel first peak": 0.2, "max components": 3000, "gain": 0.8}
nmajor = 10
weighting_scheme = "uniform"
weighting_scheme = "natural"
do_wgridding = True
epsilon = 1e-8
nthreads = 6
def main():
_, ms = sys.argv
ms = list(rve.ms2observations_all(ms, "DATA"))
assert len(ms) == 1
ms = ms[0]
ms = ms.restrict_to_stokesi().to_double_precision()
dst = ms.nyquist_resolution() / oversampling_factor
sdom = ift.RGSpace([npix, npix], [dst, dst])
print(f"Field of view: {npix*dst/rve.DEG2RAD:.2} deg")
dom = rve.default_sky_domain(sdom=sdom)
assert all(dd == 1 for dd in dom.shape[:3])
squeeze = ift.ContractionOperator(dom, (0, 1, 2))
R = rve.InterferometryResponse(ms, dom, do_wgridding, epsilon, nthreads=nthreads)
dirty = rve.dirty_image(ms, weighting_scheme, dom, do_wgridding, epsilon, nthreads=nthreads)
dirty = get_dirty(ms, dom, ms.vis)
psf = get_psf(ms, dom)
rve.ubik_tools.field2fits(dirty, "dirty.fits")
rve.ubik_tools.field2fits(psf, "psf.fits")
imajor = 0
residuals = ms.vis
model = ift.full(dom, 0.0)
while True: # Major cycle
print(f"Working on major iteration {imajor}")
model, err = major_cycle(ms, dom, residuals, model, squeeze(psf), minor_conv)
if err:
break
model_data = R(model)
residuals = ms.vis - model_data
rve.ubik_tools.field2fits(get_dirty(ms, dom, residuals), f"residual_iter{imajor}.fits")
rve.ubik_tools.field2fits(get_dirty(ms, dom, model_data), f"restored_iter{imajor}.fits")
rve.ubik_tools.field2fits(model, f"model_iter{imajor}.fits")
# rve.plot.visualize_weighted_residuals(
# [ms],
# ift.SampleList([model]),
# imajor,
# ift.Operator.identity_operator(dom),
# None,
# output_directory=".",
# do_wgridding=do_wgridding,
# epsilon=epsilon,
# nthreads=nthreads,
# io=True,
# )
if imajor + 1 == nmajor:
break
imajor += 1
def find_peak(arr):
inds = np.unravel_index(np.argmax(np.abs(arr)), arr.shape)
peak = arr[inds]
return inds, peak
def major_cycle(ms, dom, residuals, model, psf, convergence_criteria):
_, _, _, sdom = dom
assert psf.domain == ift.makeDomain(sdom)
assert dom == model.domain
assert residuals.domain == ms.vis.domain
nminor = 10
dirty = get_dirty(ms, dom, residuals)
dirty = np.squeeze(dirty.val_rw())
assert dirty.shape == sdom.shape
nx, ny = dirty.shape
_, ref_flux = find_peak(dirty)
model = np.squeeze(model.val_rw())
iminor = 0
while True:
(indx, indy), flux = find_peak(dirty)
print(f"Flux / ref_flux: {flux / ref_flux * 100:.1f}%")
if iminor == convergence_criteria["max components"]:
break
if np.abs(flux) < np.abs(ref_flux) * convergence_criteria["rel first peak"]:
break
flux *= convergence_criteria["gain"]
# Update model
model[indx, indy] += flux
# Subtract component from dirty image
off = nx // 2 - indx
if off > 0:
psf_x = slice(off, None)
dirty_x = slice(None, -off)
elif off == 0:
psf_x = slice(None)
dirty_x = slice(None)
else:
psf_x = slice(None, off)
dirty_x = slice(-off, None)
off = ny // 2 - indy
if off > 0:
psf_y = slice(off, None)
dirty_y = slice(None, -off)
elif off == 0:
psf_y = slice(None)
dirty_y = slice(None)
else:
psf_y = slice(None, off)
dirty_y = slice(-off, None)
dirty[dirty_x, dirty_y] -= flux * psf.val[psf_x, psf_y]
iminor += 1
print("Major cycle finished")
print(f"Ref flux: {ref_flux}, {iminor} components")
return ift.makeField(dom, model.reshape(dom.shape)), 0
def get_psf(ms, dom):
psf = rve.dirty_image(
ms,
weighting_scheme,
dom,
do_wgridding,
epsilon,
nthreads=nthreads,
vis=ms.vis * 0 + 1,
)
return psf / psf.val.max()
def get_dirty(ms, dom, vis):
dirty = rve.dirty_image(
ms, weighting_scheme, dom, do_wgridding, epsilon, nthreads=nthreads, vis=vis
)
return dirty
if __name__ == "__main__":
main()
...@@ -3,6 +3,7 @@ from .version import __version__ ...@@ -3,6 +3,7 @@ from .version import __version__
from .logger import logger from .logger import logger
from . import ubik_tools from . import ubik_tools
from . import plot
from .calibration import CalibrationDistributor, calibration_distribution from .calibration import CalibrationDistributor, calibration_distribution
from .config_utils import * from .config_utils import *
from .constants import * from .constants import *
... ...
......
from .baseline_histogram import *
from .sky import *
...@@ -25,7 +25,7 @@ from ..util import assert_sky_domain ...@@ -25,7 +25,7 @@ from ..util import assert_sky_domain
from ..data.observation import Observation from ..data.observation import Observation
def field2fits(field, file_name, observations=[], header_override={}): def field2fits(field, file_name, observations=[]):
import astropy.io.fits as pyfits import astropy.io.fits as pyfits
from astropy.time import Time from astropy.time import Time
...@@ -55,8 +55,7 @@ def field2fits(field, file_name, observations=[], header_override={}): ...@@ -55,8 +55,7 @@ def field2fits(field, file_name, observations=[], header_override={}):
# h["DATE-OBS"] = '2019-05-08T20:32:19.1' # TEMPORARY # h["DATE-OBS"] = '2019-05-08T20:32:19.1' # TEMPORARY
# #h["TELESCOPE"] = "MEERKAT" # TEMPORARY # #h["TELESCOPE"] = "MEERKAT" # TEMPORARY
# h["OBSERVER"] = "xxx" # TEMPORARY # h["OBSERVER"] = "xxx" # TEMPORARY
h["OBSRA"] = direction.phase_center[0] * 180 / np.pi if direction is not None else 0.0 # h["OBJECT"] = "xxx" # TEMPORARY
h["OBSDEC"] = direction.phase_center[1] * 180 / np.pi if direction is not None else 0.0
# h["SPECSYS"] = "TOPOCENT" # h["SPECSYS"] = "TOPOCENT"
h["CTYPE1"] = "RA---SIN" h["CTYPE1"] = "RA---SIN"
h["CRVAL1"] = direction.phase_center[0] * 180 / np.pi if direction is not None else 0.0 h["CRVAL1"] = direction.phase_center[0] * 180 / np.pi if direction is not None else 0.0
...@@ -95,10 +94,7 @@ def field2fits(field, file_name, observations=[], header_override={}): ...@@ -95,10 +94,7 @@ def field2fits(field, file_name, observations=[], header_override={}):
h["DATE-MAP"] = Time(time.time(), format="unix").iso.split()[0] h["DATE-MAP"] = Time(time.time(), format="unix").iso.split()[0]
if direction is not None: if direction is not None:
h["EQUINOX"] = direction.equinox h["EQUINOX"] = direction.equinox
multiple_times = tdom.size > 1
for kk, vv in header_override.items():
h[kk] = vv
for t_ind, t_val in enumerate(tdom.coordinates): for t_ind, t_val in enumerate(tdom.coordinates):
val = field.val[:, t_ind] # Select time val = field.val[:, t_ind] # Select time
val = np.transpose(val, (0, 1, 3, 2)) # Switch spatial axes val = np.transpose(val, (0, 1, 3, 2)) # Switch spatial axes
...@@ -107,7 +103,8 @@ def field2fits(field, file_name, observations=[], header_override={}): ...@@ -107,7 +103,8 @@ def field2fits(field, file_name, observations=[], header_override={}):
hdu = pyfits.PrimaryHDU(val, header=h) hdu = pyfits.PrimaryHDU(val, header=h)
hdulist = pyfits.HDUList([hdu]) hdulist = pyfits.HDUList([hdu])
base, ext = splitext(file_name) base, ext = splitext(file_name)
if len(tdom.coordinates) == 1: if multiple_times:
hdulist.writeto(base + ext, overwrite=True) fname = base + f"timebin{t_val}" + ext
else: else:
hdulist.writeto(base + f"time{t_val}" + ext, overwrite=True) fname = base + ext
hdulist.writeto(fname, overwrite=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment