Commit 7313728c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve MS reader

parent 7c9cf038
Pipeline #82060 passed with stages
in 16 minutes and 13 seconds
......@@ -20,7 +20,38 @@ import matplotlib.pyplot as plt
import numpy as np
def read_ms(name):
def get_indices(name):
from os.path import join
from casacore.tables import table
with table(join(name, 'POLARIZATION'), readonly=True, ack=False) as t:
pol = list(t.getcol('CORR_TYPE')[0])
if set(pol) <= set([5, 6, 7, 8]):
ind = [pol.index(5), pol.index(8)]
else:
ind = [pol.index(9), pol.index(12)]
return ind
def determine_weighting(t):
fullwgt = False
weightcol = "WEIGHT"
try:
t.getcol("WEIGHT_SPECTRUM", startrow=0, nrow=1)
weightcol = "WEIGHT_SPECTRUM"
fullwgt = True
except:
pass
return fullwgt, weightcol
def extra_checks(t):
if len(set(t.getcol('FIELD_ID'))) != 1:
raise RuntimeError
if len(set(t.getcol('DATA_DESC_ID'))) != 1:
raise RuntimeError
def read_ms_i(name):
# Assumptions:
# - Only one field
# - Only one spectral window
......@@ -31,74 +62,73 @@ def read_ms(name):
with table(join(name, 'SPECTRAL_WINDOW'), readonly=True, ack=False) as t:
freq = t.getcol('CHAN_FREQ')[0]
nchan = freq.shape[0]
with table(join(name, 'POLARIZATION'), readonly=True, ack=False) as t:
pol = list(t.getcol('CORR_TYPE')[0])
if set(pol) <= set([5, 6, 7, 8]):
ind = [pol.index(5), pol.index(8)]
else:
ind = [pol.index(9), pol.index(12)]
ind = get_indices(name)
with table(name, readonly=True, ack=False) as t:
fullwgt = 'WEIGHT_SPECTRUM' in t.colnames()
if len(set(t.getcol('FIELD_ID'))) != 1:
raise RuntimeError
if len(set(t.getcol('DATA_DESC_ID'))) != 1:
raise RuntimeError
uvw = t.getcol("UVW")
nrow = uvw.shape[0]
step = max(1, t.nrows()//100) # how many rows to read in every step
fullwgt, weightcol = determine_weighting(t)
extra_checks(t)
nrow = t.nrows()
active_rows = np.ones(nrow, dtype=np.bool)
active_channels = np.zeros(nchan, dtype=np.bool)
step = max(1, nrow//100) # how many rows to read in every step
# determine which subset of rows/channels we need to input
start = 0
vis = np.empty((nrow, nchan), dtype=np.complex64)
if fullwgt:
wgt = np.empty((nrow, nchan), dtype=np.float32)
else:
wgt = np.empty((nrow), dtype=np.float32)
flags = np.empty((nrow, nchan), dtype=np.bool)
while start < nrow:
stop = min(nrow, start+step)
tvis = t.getcol("DATA", startrow=start, nrow=stop-start)
ncorr = tvis.shape[2]
tvis = np.sum(tvis[:, :, ind], axis=2)
vis[start:stop, :] = tvis
if fullwgt:
twgt = t.getcol("WEIGHT_SPECTRUM", startrow=start, nrow=stop-start)[:, :, ind]
wgt[start:stop, :] = 1/np.sum(1/twgt, axis=2)
else:
twgt = t.getcol("WEIGHT", startrow=start, nrow=stop-start)[:, :, ind]
wgt[start:stop] = 1/np.sum(1/twgt, axis=1)
tflags = t.getcol('FLAG', startrow=start, nrow=stop-start)[:, :, ind]
flags[start:stop, :] = np.any(tflags.astype(np.bool), axis=2)
tflags = t.getcol('FLAG', startrow=start, nrow=stop-start)
ncorr = tflags.shape[2]
tflags = tflags[..., ind]
tflags = np.any(tflags.astype(np.bool), axis=-1)
twgt = t.getcol(weightcol, startrow=start, nrow=stop-start)[..., ind]
twgt = 1/np.sum(1/twgt, axis=-1)
tflags[twgt==0] = True
active_rows[start:stop] = np.invert(np.all(tflags, axis=-1))
active_channels = np.logical_or(active_channels, np.invert(np.all(tflags, axis=0)))
start = stop
# visibilities with weight 0 might as well be flagged
if fullwgt:
flags[wgt==0] = True
else:
flags[np.broadcast_to(wgt.reshape((-1,1)), vis.shape)==0] = True
nrealrows, nrealchan = np.sum(active_rows), np.sum(active_channels)
start, realstart = 0, 0
vis = np.empty((nrealrows, nrealchan), dtype=np.complex64)
wgtshp = (nrealrows, nrealchan) if fullwgt else (nrealrows,)
wgt = np.empty(wgtshp, dtype=np.float32)
flags = np.empty((nrealrows, nrealchan), dtype=np.bool)
while start < nrow:
stop = min(nrow, start+step)
realstop = realstart+np.sum(active_rows[start:stop])
if realstop == realstart:
start = stop
realstart = realstop
continue
tvis = t.getcol("DATA", startrow=start, nrow=stop-start)[..., ind]
tvis = np.sum(tvis, axis=-1)
tvis = tvis[active_rows[start:stop]][:, active_channels]
tflags = t.getcol('FLAG', startrow=start, nrow=stop-start)[..., ind]
tflags = np.any(tflags.astype(np.bool), axis=-1)
tflags = tflags[active_rows[start:stop]][:, active_channels]
twgt = t.getcol(weightcol, startrow=start, nrow=stop-start)[..., ind]
twgt = 1/np.sum(1/twgt, axis=-1)
twgt = twgt[active_rows[start:stop]]
if fullwgt:
twgt = twgt[:, active_channels]
tflags[twgt==0] = True
vis[realstart:realstop] = tvis
wgt[realstart:realstop] = twgt
flags[realstart:realstop] = tflags
start, realstart = stop, realstop
uvw = t.getcol("UVW")[active_rows]
print('# Rows: {}'.format(vis.shape[0]))
print('# Channels: {}'.format(vis.shape[1]))
print('# Rows: {} ({} fully flagged)'.format(nrow, nrow-vis.shape[0]))
print('# Channels: {} ({} fully flagged)'.format(nchan, nchan-vis.shape[1]))
print('# Correlations: {}'.format(ncorr))
print('Full weights' if fullwgt else 'Row-only weights')
print("{} % flagged".format(np.sum(flags)/flags.size*100))
# cut out unused rows/channels
rows_with_data = np.invert(np.all(flags, axis=1))
n_empty_rows = nrow-np.sum(rows_with_data)
print("Completely flagged rows: {}".format(n_empty_rows))
if n_empty_rows > 0:
uvw = uvw[rows_with_data,:]
vis = vis[rows_with_data,:]
wgt = wgt[rows_with_data,:] if fullwgt else wgt[rows_with_data]
flags = flags[rows_with_data,:]
channels_with_data = np.invert(np.all(flags, axis=0))
n_empty_channels = nchan-np.sum(channels_with_data)
print("Completely flagged channels: {}".format(n_empty_channels))
if n_empty_channels > 0:
freq = freq[channels_with_data]
vis = vis[:, channels_with_data]
wgt = wgt[:, channels_with_data] if fullwgt else wgt
flags = flags[:, channels_with_data]
nflagged = np.sum(flags) + (nrow-vis.shape[0])*nchan + (nchan-vis.shape[1])*nrow
print("{} % flagged".format(nflagged/(nrow*nchan)*100))
freq = freq[active_channels]
# blow up wgt to the right dimensions if necessary
if not fullwgt:
......@@ -112,13 +142,14 @@ def read_ms(name):
def main():
ms = '/home/martin/ms/supernovashell.55.7+3.4.spw0.ms'
ms = '/home/martin/ms/1052735056_cleaned.ms'
uvw, freq, vis, wgt, flags = read_ms(ms)
# ms, fov_deg = '/home/martin/ms/supernovashell.55.7+3.4.spw0.ms', 2.
# ms, fov_deg = '/home/martin/ms/1052736496-averaged.ms', 45.
ms, fov_deg = '/home/martin/ms/1052735056_cleaned.ms', 45.
uvw, freq, vis, wgt, flags = read_ms_i(ms)
npixdirty = 1200
DEG2RAD = np.pi/180
pixsize = 45/npixdirty*DEG2RAD
pixsize = fov_deg/npixdirty*DEG2RAD
nthreads = 2
epsilon = 1e-4
print('Start gridding...')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment