diff --git a/resolve/data/observation.py b/resolve/data/observation.py index b4fe3301062a6d1f032b8626ceb1d84ac24b7963..41e459f57054e8ee56a4ebc24b343008e4b2dcb2 100644 --- a/resolve/data/observation.py +++ b/resolve/data/observation.py @@ -9,7 +9,7 @@ import nifty8 as ift from .antenna_positions import AntennaPositions from ..constants import SPEEDOFLIGHT from .direction import Direction, Directions -from ..mpi import onlymaster +from ..mpi import onlymaster, master from .polarization import Polarization from ..util import compare_attributes, my_assert, my_assert_isinstance, my_asserteq @@ -354,6 +354,29 @@ class Observation(BaseObservation): direction, ) + @staticmethod + def split_data_file(self, data_path, ntask, target_folder, base_name, nwork): + from os import makedirs + makedirs(target_folder, exist_ok=True) + + for rank in range(ntask): + lo, hi = ift.utilities.shareRange(nwork, ntask, rank) + obs = Observation.load(data_path, (lo, hi)) + obs.save(f"{target_folder}/{base_name}_{rank}.npz") + + @staticmethod + def mpi_load(self, data_folder, base_name, full_data_set, nwork, comm=None): + if master: + from os.path import isdir + if not isdir(data_folder): + Observation.split_data_file(full_data_set, comm.Get_size, data_folder, base_name, nwork) + if comm is None: + return Observation.load(full_data_set) + + comm.Barrier() + return Observation.load(f"{data_folder}/{base_name}_{comm.Get_rank()}.npz") + + def flags_to_nan(self): if self.fraction_useful == 1.: return self