Skip to content
Snippets Groups Projects
Commit 6940587f authored by Philipp Arras's avatar Philipp Arras
Browse files

Plotting and export improvements

parent 4dccc027
No related branches found
No related tags found
1 merge request!39Draft: Clean experiments
import resolve as rve
import ducc0
import numpy as np
import nifty8 as ift
import sys
npix = 1000
npix = 2 * ducc0.fft.good_size(npix // 2)
oversampling_factor = 4
minor_conv = {"rel first peak": 0.2, "max components": 3000, "gain": 0.8}
nmajor = 10
weighting_scheme = "uniform"
weighting_scheme = "natural"
do_wgridding = True
epsilon = 1e-8
nthreads = 6
def main():
_, ms = sys.argv
ms = list(rve.ms2observations_all(ms, "DATA"))
assert len(ms) == 1
ms = ms[0]
ms = ms.restrict_to_stokesi().to_double_precision()
dst = ms.nyquist_resolution() / oversampling_factor
sdom = ift.RGSpace([npix, npix], [dst, dst])
print(f"Field of view: {npix*dst/rve.DEG2RAD:.2} deg")
dom = rve.default_sky_domain(sdom=sdom)
assert all(dd == 1 for dd in dom.shape[:3])
squeeze = ift.ContractionOperator(dom, (0, 1, 2))
R = rve.InterferometryResponse(ms, dom, do_wgridding, epsilon, nthreads=nthreads)
dirty = rve.dirty_image(ms, weighting_scheme, dom, do_wgridding, epsilon, nthreads=nthreads)
dirty = get_dirty(ms, dom, ms.vis)
psf = get_psf(ms, dom)
rve.ubik_tools.field2fits(dirty, "dirty.fits")
rve.ubik_tools.field2fits(psf, "psf.fits")
imajor = 0
residuals = ms.vis
model = ift.full(dom, 0.0)
while True: # Major cycle
print(f"Working on major iteration {imajor}")
model, err = major_cycle(ms, dom, residuals, model, squeeze(psf), minor_conv)
if err:
break
model_data = R(model)
residuals = ms.vis - model_data
rve.ubik_tools.field2fits(get_dirty(ms, dom, residuals), f"residual_iter{imajor}.fits")
rve.ubik_tools.field2fits(get_dirty(ms, dom, model_data), f"restored_iter{imajor}.fits")
rve.ubik_tools.field2fits(model, f"model_iter{imajor}.fits")
# rve.plot.visualize_weighted_residuals(
# [ms],
# ift.SampleList([model]),
# imajor,
# ift.Operator.identity_operator(dom),
# None,
# output_directory=".",
# do_wgridding=do_wgridding,
# epsilon=epsilon,
# nthreads=nthreads,
# io=True,
# )
if imajor + 1 == nmajor:
break
imajor += 1
def find_peak(arr):
inds = np.unravel_index(np.argmax(np.abs(arr)), arr.shape)
peak = arr[inds]
return inds, peak
def major_cycle(ms, dom, residuals, model, psf, convergence_criteria):
_, _, _, sdom = dom
assert psf.domain == ift.makeDomain(sdom)
assert dom == model.domain
assert residuals.domain == ms.vis.domain
nminor = 10
dirty = get_dirty(ms, dom, residuals)
dirty = np.squeeze(dirty.val_rw())
assert dirty.shape == sdom.shape
nx, ny = dirty.shape
_, ref_flux = find_peak(dirty)
model = np.squeeze(model.val_rw())
iminor = 0
while True:
(indx, indy), flux = find_peak(dirty)
print(f"Flux / ref_flux: {flux / ref_flux * 100:.1f}%")
if iminor == convergence_criteria["max components"]:
break
if np.abs(flux) < np.abs(ref_flux) * convergence_criteria["rel first peak"]:
break
flux *= convergence_criteria["gain"]
# Update model
model[indx, indy] += flux
# Subtract component from dirty image
off = nx // 2 - indx
if off > 0:
psf_x = slice(off, None)
dirty_x = slice(None, -off)
elif off == 0:
psf_x = slice(None)
dirty_x = slice(None)
else:
psf_x = slice(None, off)
dirty_x = slice(-off, None)
off = ny // 2 - indy
if off > 0:
psf_y = slice(off, None)
dirty_y = slice(None, -off)
elif off == 0:
psf_y = slice(None)
dirty_y = slice(None)
else:
psf_y = slice(None, off)
dirty_y = slice(-off, None)
dirty[dirty_x, dirty_y] -= flux * psf.val[psf_x, psf_y]
iminor += 1
print("Major cycle finished")
print(f"Ref flux: {ref_flux}, {iminor} components")
return ift.makeField(dom, model.reshape(dom.shape)), 0
def get_psf(ms, dom):
psf = rve.dirty_image(
ms,
weighting_scheme,
dom,
do_wgridding,
epsilon,
nthreads=nthreads,
vis=ms.vis * 0 + 1,
)
return psf / psf.val.max()
def get_dirty(ms, dom, vis):
dirty = rve.dirty_image(
ms, weighting_scheme, dom, do_wgridding, epsilon, nthreads=nthreads, vis=vis
)
return dirty
if __name__ == "__main__":
main()
...@@ -3,6 +3,7 @@ from .version import __version__ ...@@ -3,6 +3,7 @@ from .version import __version__
from .logger import logger from .logger import logger
from . import ubik_tools from . import ubik_tools
from . import plot
from .calibration import CalibrationDistributor, calibration_distribution from .calibration import CalibrationDistributor, calibration_distribution
from .config_utils import * from .config_utils import *
from .constants import * from .constants import *
......
from .baseline_histogram import *
from .sky import *
...@@ -25,7 +25,7 @@ from ..util import assert_sky_domain ...@@ -25,7 +25,7 @@ from ..util import assert_sky_domain
from ..data.observation import Observation from ..data.observation import Observation
def field2fits(field, file_name, observations=[], header_override={}): def field2fits(field, file_name, observations=[]):
import astropy.io.fits as pyfits import astropy.io.fits as pyfits
from astropy.time import Time from astropy.time import Time
...@@ -55,8 +55,7 @@ def field2fits(field, file_name, observations=[], header_override={}): ...@@ -55,8 +55,7 @@ def field2fits(field, file_name, observations=[], header_override={}):
# h["DATE-OBS"] = '2019-05-08T20:32:19.1' # TEMPORARY # h["DATE-OBS"] = '2019-05-08T20:32:19.1' # TEMPORARY
# #h["TELESCOPE"] = "MEERKAT" # TEMPORARY # #h["TELESCOPE"] = "MEERKAT" # TEMPORARY
# h["OBSERVER"] = "xxx" # TEMPORARY # h["OBSERVER"] = "xxx" # TEMPORARY
h["OBSRA"] = direction.phase_center[0] * 180 / np.pi if direction is not None else 0.0 # h["OBJECT"] = "xxx" # TEMPORARY
h["OBSDEC"] = direction.phase_center[1] * 180 / np.pi if direction is not None else 0.0
# h["SPECSYS"] = "TOPOCENT" # h["SPECSYS"] = "TOPOCENT"
h["CTYPE1"] = "RA---SIN" h["CTYPE1"] = "RA---SIN"
h["CRVAL1"] = direction.phase_center[0] * 180 / np.pi if direction is not None else 0.0 h["CRVAL1"] = direction.phase_center[0] * 180 / np.pi if direction is not None else 0.0
...@@ -95,10 +94,7 @@ def field2fits(field, file_name, observations=[], header_override={}): ...@@ -95,10 +94,7 @@ def field2fits(field, file_name, observations=[], header_override={}):
h["DATE-MAP"] = Time(time.time(), format="unix").iso.split()[0] h["DATE-MAP"] = Time(time.time(), format="unix").iso.split()[0]
if direction is not None: if direction is not None:
h["EQUINOX"] = direction.equinox h["EQUINOX"] = direction.equinox
multiple_times = tdom.size > 1
for kk, vv in header_override.items():
h[kk] = vv
for t_ind, t_val in enumerate(tdom.coordinates): for t_ind, t_val in enumerate(tdom.coordinates):
val = field.val[:, t_ind] # Select time val = field.val[:, t_ind] # Select time
val = np.transpose(val, (0, 1, 3, 2)) # Switch spatial axes val = np.transpose(val, (0, 1, 3, 2)) # Switch spatial axes
...@@ -107,7 +103,8 @@ def field2fits(field, file_name, observations=[], header_override={}): ...@@ -107,7 +103,8 @@ def field2fits(field, file_name, observations=[], header_override={}):
hdu = pyfits.PrimaryHDU(val, header=h) hdu = pyfits.PrimaryHDU(val, header=h)
hdulist = pyfits.HDUList([hdu]) hdulist = pyfits.HDUList([hdu])
base, ext = splitext(file_name) base, ext = splitext(file_name)
if len(tdom.coordinates) == 1: if multiple_times:
hdulist.writeto(base + ext, overwrite=True) fname = base + f"timebin{t_val}" + ext
else: else:
hdulist.writeto(base + f"time{t_val}" + ext, overwrite=True) fname = base + ext
hdulist.writeto(fname, overwrite=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment