diff --git a/resolve/data/observation.py b/resolve/data/observation.py index 0a9fe22862014482de0970cbf9f050f6371ffd83..96419f48d76041639f2f0cfa93e3517d08ba60bb 100644 --- a/resolve/data/observation.py +++ b/resolve/data/observation.py @@ -332,27 +332,20 @@ class Observation(BaseObservation): if val.size == 0: val = None antpos.append(val) + antpos = AntennaPositions.from_list(antpos) pol = Polarization.from_list(dct["polarization"]) direction = Direction.from_list(dct["direction"]) - if lo_hi_index is None: - vis = dct["vis"] - weight = dct["weight"] - freq = dct["freq"] - else: + vis = dct["vis"] + wgt = dct["weight"] + freq = dct["freq"] + if lo_hi_index is not None: slc = slice(*lo_hi_index) # Convert view into its own array - vis = dct["vis"][..., slc].copy() - weight = dct["weight"][..., slc].copy() - freq = dct["freq"][slc].copy() + vis = vis[..., slc].copy() + wgt = wgt[..., slc].copy() + freq = freq[slc].copy() del dct - return Observation( - AntennaPositions.from_list(antpos), - vis, - weight, - pol, - freq, - direction, - ) + return Observation(antpos, vis, wgt, pol, freq, direction) def flags_to_nan(self): if self.fraction_useful == 1.: @@ -375,32 +368,27 @@ class Observation(BaseObservation): if comm is None: local_imaging_bands = range(n_imaging_bands) else: - local_imaging_bands = range( - *ift.utilities.shareRange( - n_imaging_bands, comm.Get_size(), comm.Get_rank() - ) - ) + lo, hi = ift.utilities.shareRange( n_imaging_bands, comm.Get_size(), comm.Get_rank()) + local_imaging_bands = range(lo, hi) full_obs = Observation.load(file_name) - obs_list = [ - full_obs.get_freqs_by_slice( - slice(*ift.utilities.shareRange(len(global_freqs), n_imaging_bands, ii)) - ) - for ii in local_imaging_bands - ] + obs_list = [] + for ii in local_imaging_bands: + slc = slice(*ift.utilities.shareRange(len(global_freqs), n_imaging_bands, ii)) + obs_list.append(full_obs.get_freqs_by_slice(slc)) nu0 = global_freqs.mean() return obs_list, nu0 - def __getitem__(self, slc): - return Observation( - self._antpos[slc], - self._vis[:, slc], - self._weight[:, slc], - self._polarization, - self._freq, - self._direction, - ) - - def get_freqs(self, frequency_list): + def __getitem__(self, slc, copy=False): + ap = self._antpos[slc] + vis = self._vis[slc] + wgt = self._weight[:, slc] + if copy: + ap = ap.copy() + vis = vis.copy() + wgt = wgt.copy() + return Observation(ap, vis, wgt, self._polarization, self._freq, self._direction) + + def get_freqs(self, frequency_list, copy=False): """Return observation that contains a subset of the present frequencies Parameters @@ -410,17 +398,17 @@ class Observation(BaseObservation): """ mask = np.zeros(self.nfreq, dtype=bool) mask[frequency_list] = 1 - return self.get_freqs_by_slice(mask) - - def get_freqs_by_slice(self, slc): - return Observation( - self._antpos, - self._vis[..., slc], - self._weight[..., slc], - self._polarization, - self._freq[slc], - self._direction, - ) + return self.get_freqs_by_slice(mask, copy) + + def get_freqs_by_slice(self, slc, copy=False): + vis = self._vis[..., slc] + wgt = self._weight[..., slc] + freq = self._freq[slc] + if copy: + vis = vis.copy() + wgt = wgt.copy() + freq = freq.copy() + return Observation( self._antpos, vis, wgt, self._polarization, freq, self._direction) def average_stokesi(self): my_asserteq(self._vis.shape[0], 2)