diff --git a/resolve/__init__.py b/resolve/__init__.py index 4a1ade5b438f0cae57c7cee0a4e78a7ded42e6c3..79b64ccfd1a6632629221bbfc1a899bd35be9671 100644 --- a/resolve/__init__.py +++ b/resolve/__init__.py @@ -26,3 +26,4 @@ from .polarization_matrix_exponential import * from .response import MfResponse, ResponseDistributor, StokesIResponse, SingleResponse from .simple_operators import * from .util import * +from .extra import mpi_load diff --git a/resolve/data/observation.py b/resolve/data/observation.py index 67682e24675b2015421602605ef88a7380100d03..0a9fe22862014482de0970cbf9f050f6371ffd83 100644 --- a/resolve/data/observation.py +++ b/resolve/data/observation.py @@ -9,7 +9,7 @@ import nifty8 as ift from .antenna_positions import AntennaPositions from ..constants import SPEEDOFLIGHT from .direction import Direction, Directions -from ..mpi import onlymaster, master +from ..mpi import onlymaster from .polarization import Polarization from ..util import compare_attributes, my_assert, my_assert_isinstance, my_asserteq @@ -354,30 +354,6 @@ class Observation(BaseObservation): direction, ) - @staticmethod - def split_data_file(data_path, ntask, target_folder, base_name, nwork, compress): - from os import makedirs - makedirs(target_folder, exist_ok=True) - - obs = Observation.load(data_path) - - for rank in range(ntask): - lo, hi = ift.utilities.shareRange(nwork, ntask, rank) - sliced_obs = obs.get_freqs_by_slice(slice(*(lo, hi))) - sliced_obs.save(f"{target_folder}/{base_name}_{rank}.npz", compress=compress) - - @staticmethod - def mpi_load(data_folder, base_name, full_data_set, nwork, comm=None, compress=False): - if master: - from os.path import isdir - if not isdir(data_folder): - Observation.split_data_file(full_data_set, comm.Get_size(), data_folder, base_name, nwork, compress) - if comm is None: - return Observation.load(full_data_set) - - comm.Barrier() - return Observation.load(f"{data_folder}/{base_name}_{comm.Get_rank()}.npz") - def flags_to_nan(self): if self.fraction_useful == 1.: return self diff --git a/resolve/extra.py b/resolve/extra.py new file mode 100644 index 0000000000000000000000000000000000000000..c707879dde6d8d6132cb7848a30e694375e66518 --- /dev/null +++ b/resolve/extra.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# Copyright(C) 2019-2021 Max-Planck-Society +# Author: Philipp Arras + +import nifty8 as ift + +from .mpi import master +from data.observation import Observation + + +def split_data_file(data_path, n_task, target_folder, base_name, n_work, compress): + from os import makedirs + makedirs(target_folder, exist_ok=True) + + obs = Observation.load(data_path) + + for rank in range(n_task): + lo, hi = ift.utilities.shareRange(n_work, n_task, rank) + sliced_obs = obs.get_freqs_by_slice(slice(*(lo, hi))) + sliced_obs.save(f"{target_folder}/{base_name}_{rank}.npz", compress=compress) + + +def mpi_load(data_folder, base_name, full_data_set, n_work, comm=None, compress=False): + if master: + from os.path import isdir + if not isdir(data_folder): + split_data_file(full_data_set, comm.Get_size(), data_folder, base_name, n_work, compress) + if comm is None: + return Observation.load(full_data_set) + + comm.Barrier() + return Observation.load(f"{data_folder}/{base_name}_{comm.Get_rank()}.npz")