Skip to content
Snippets Groups Projects
Commit 6e095fcc authored by Philipp Arras's avatar Philipp Arras
Browse files

Add baseline functionality

parent 06107a95
No related branches found
No related tags found
1 merge request!25Calibration
Pipeline #106388 passed
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Copyright(C) 2019-2021 Max-Planck-Society # Copyright(C) 2019-2021 Max-Planck-Society
# Author: Philipp Arras # Author: Philipp Arras
import nifty8 as ift
import numpy as np import numpy as np
from ..util import (compare_attributes, my_assert, my_assert_isinstance, from ..util import (compare_attributes, my_assert, my_assert_isinstance,
...@@ -98,7 +99,7 @@ class AntennaPositions: ...@@ -98,7 +99,7 @@ class AntennaPositions:
def ant2(self): def ant2(self):
return self._ant2 return self._ant2
def extract_baseline(self, antenna1, antenna2, data): def extract_baseline(self, antenna1, antenna2, field):
"""Extract data that belongs to a given baseline. """Extract data that belongs to a given baseline.
Parameters Parameters
...@@ -107,13 +108,13 @@ class AntennaPositions: ...@@ -107,13 +108,13 @@ class AntennaPositions:
Antenna index of the first antenna that is selected. Antenna index of the first antenna that is selected.
antenna2: int antenna2: int
Antenna index of the second antenna that is selected. Antenna index of the second antenna that is selected.
data: numpy.ndarray field: nifty8.Field
Data array. Shape `(n_pol, n_row, n_freq)`. `n_row` must equal Data field. Shape `(n_pol, n_row, n_freq)`. `n_row` must equal
`len(self)`. `len(self)`.
Returns Returns
------- -------
numpy.ndarray nifty8.Field
Entries of `data` that correspond to the selected baseline. Shape Entries of `data` that correspond to the selected baseline. Shape
`(n_pol, n_time, n_freq)`. `(n_pol, n_time, n_freq)`.
""" """
...@@ -121,15 +122,15 @@ class AntennaPositions: ...@@ -121,15 +122,15 @@ class AntennaPositions:
raise RuntimeError("This algorithm assumes ant1<ant2.") raise RuntimeError("This algorithm assumes ant1<ant2.")
ut = np.sort(np.array(list(self.unique_times()))) ut = np.sort(np.array(list(self.unique_times())))
npol, nrow, nfreq = data.shape npol, nrow, nfreq = field.shape
my_asserteq(nrow, len(self)) my_asserteq(nrow, len(self))
my_assert_isinstance(data, np.ndarray) my_assert_isinstance(field, ift.Field)
my_assert_isinstance(antenna1, antenna2, int) my_assert_isinstance(antenna1, antenna2, int)
# Select by antenna labels # Select by antenna labels
ind = np.logical_and(self.ant1 == antenna1, ind = np.logical_and(self.ant1 == antenna1,
self.ant2 == antenna2) self.ant2 == antenna2)
data = data[:, ind] data = field.val[:, ind]
tt = self.time[ind] tt = self.time[ind]
# Sort by time # Sort by time
...@@ -141,7 +142,9 @@ class AntennaPositions: ...@@ -141,7 +142,9 @@ class AntennaPositions:
out = np.empty((npol, ut.size, nfreq), dtype=data.dtype) out = np.empty((npol, ut.size, nfreq), dtype=data.dtype)
out[:] = np.nan out[:] = np.nan
out[:, np.searchsorted(ut, tt)] = data out[:, np.searchsorted(ut, tt)] = data
return out elif np.array_equal(tt, ut):
if np.array_equal(tt, ut): out = data
return data else:
raise RuntimeError raise RuntimeError
dom = field.domain[0], ift.UnstructuredDomain(out.shape[1]), field.domain[2]
return ift.makeField(dom, out)
...@@ -103,6 +103,9 @@ class BaseObservation: ...@@ -103,6 +103,9 @@ class BaseObservation:
non-flagged data points from a field defined on `self.vis.domain`.""" non-flagged data points from a field defined on `self.vis.domain`."""
return ift.MaskOperator(self.flags) return ift.MaskOperator(self.flags)
def flags_to_nan(self):
raise NotImplementedError
def max_snr(self): def max_snr(self):
"""float: Maximum signal-to-noise ratio.""" """float: Maximum signal-to-noise ratio."""
snr = (self.vis * self.weight.sqrt()).abs() snr = (self.vis * self.weight.sqrt()).abs()
...@@ -342,6 +345,16 @@ class Observation(BaseObservation): ...@@ -342,6 +345,16 @@ class Observation(BaseObservation):
direction, direction,
) )
def flags_to_nan(self):
if self.fraction_useful == 1.:
return self
vis = self._vis.copy()
weight = self._weight.copy()
vis[self.flags.val] = np.nan
weight[self.flags.val] = np.nan
return Observation(self._antpos, vis, weight, self._pol, self._freq,
self._direction)
@staticmethod @staticmethod
def load_mf(file_name, n_imaging_bands, comm=None): def load_mf(file_name, n_imaging_bands, comm=None):
if comm is not None: if comm is not None:
......
...@@ -77,3 +77,14 @@ def imshow(arr, ax=None, **kwargs): ...@@ -77,3 +77,14 @@ def imshow(arr, ax=None, **kwargs):
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
return ax.imshow(arr.T, origin="lower", **kwargs) return ax.imshow(arr.T, origin="lower", **kwargs)
def rows_to_baselines(antenna_positions, data_field):
ua = antenna_positions.unique_antennas()
my_assert(np.all(antenna_positions.ant1 < antenna_positions.ant2))
res = {}
for iant1 in range(len(ua)):
for iant2 in range(iant1+1, len(ua)):
res[f"{iant1}-{iant2}"] = antenna_positions.extract_baseline(iant1, iant2, data_field)
return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment