Commit a9cf3437 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

better weight handling

parent ff01865b
Pipeline #82006 passed with stages
in 16 minutes and 16 seconds
......@@ -38,16 +38,20 @@ def read_ms(name):
else:
ind = [pol.index(9), pol.index(12)]
with table(name, readonly=True, ack=False) as t:
fullwgt = 'WEIGHT_SPECTRUM' in t.colnames()
if len(set(t.getcol('FIELD_ID'))) != 1:
raise RuntimeError
if len(set(t.getcol('DATA_DESC_ID'))) != 1:
raise RuntimeError
uvw = t.getcol("UVW")
nrow = uvw.shape[0]
step = 1000 # how many rows to read in every step
step = max(1, t.nrows()//100) # how many rows to read in every step
start = 0
vis = np.empty((nrow, nchan), dtype=np.complex64)
wgt = np.empty((nrow, nchan), dtype=np.float32)
if fullwgt:
wgt = np.empty((nrow, nchan), dtype=np.float32)
else:
wgt = np.empty((nrow), dtype=np.float32)
flags = np.empty((nrow, nchan), dtype=np.bool)
while start < nrow:
stop = min(nrow, start+step)
......@@ -55,20 +59,26 @@ def read_ms(name):
ncorr = tvis.shape[2]
tvis = np.sum(tvis[:, :, ind], axis=2)
vis[start:stop, :] = tvis
twgt = t.getcol("WEIGHT", startrow=start, nrow=stop-start)
twgt = 1/np.sum(1/twgt, axis=1)
wgt[start:stop, :] = np.repeat(twgt[:, None], len(freq), axis=1)
if fullwgt:
twgt = t.getcol("WEIGHT_SPECTRUM", startrow=start, nrow=stop-start)
wgt[start:stop, :] = 1/np.sum(1/twgt, axis=2)
else:
twgt = t.getcol("WEIGHT", startrow=start, nrow=stop-start)
wgt[start:stop] = 1/np.sum(1/twgt, axis=1)
tflags = t.getcol('FLAG', startrow=start, nrow=stop-start)
flags[start:stop, :] = np.any(tflags.astype(np.bool), axis=2)
start = stop
# flagged visibilities get weight 0
wgt[flags] = 0
# visibilities with weight 0 might as well be flagged
flags[wgt==0] = True
if fullwgt:
flags[wgt==0] = True
else:
flags[np.broadcast_to(wgt.reshape((-1,1)), vis.shape)==0] = True
print('# Rows: {}'.format(vis.shape[0]))
print('# Channels: {}'.format(vis.shape[1]))
print('# Correlations: {}'.format(ncorr))
print('Full weights' if fullwgt else 'Row-only weights')
print("{} % flagged".format(np.sum(flags)/flags.size*100))
# cut out unused rows/channels
......@@ -78,20 +88,26 @@ def read_ms(name):
if n_empty_rows > 0:
uvw = uvw[rows_with_data,:]
vis = vis[rows_with_data,:]
wgt = wgt[rows_with_data,:]
wgt = wgt[rows_with_data,:] if fullwgt else wgt[rows_with_data]
flags = flags[rows_with_data,:]
channels_with_data = np.invert(np.all(flags, axis=0))
n_empty_channels = nchan-np.sum(channels_with_data)
print("Completely flagged channels: {}".format(n_empty_channels))
if n_empty_channels > 0:
freq = freq[channels_with_data]
vis = vis[:, channels_with_data]
wgt = wgt[:, channels_with_data]
wgt = wgt[:, channels_with_data] if fullwgt else wgt
flags = flags[:, channels_with_data]
# blow up wgt to the right dimensions if necessary
if not fullwgt:
wgt = np.broadcast_to(wgt.reshape((-1,1)), vis.shape)
return (np.ascontiguousarray(uvw),
np.ascontiguousarray(freq),
np.ascontiguousarray(vis),
np.ascontiguousarray(wgt),
np.ascontiguousarray(wgt) if fullwgt else wgt,
1-flags.astype(np.uint8))
......@@ -108,7 +124,6 @@ def main():
print('Start gridding...')
do_wstacking = True
wgt = np.where(np.abs(vis)==0, 0, wgt)
t0 = time()
dirty = wgridder.ms2dirty(uvw, freq, vis, wgt, npixdirty, npixdirty, pixsize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment