From 80e87bd08a45b10adb14b360437fc1857edfe1e1 Mon Sep 17 00:00:00 2001
From: Simon Ding <ding@mpa-garching.mpg.de>
Date: Thu, 2 Sep 2021 16:15:16 +0200
Subject: [PATCH] introducing mpi loading

---
 resolve/data/observation.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/resolve/data/observation.py b/resolve/data/observation.py
index b4fe3301..41e459f5 100644
--- a/resolve/data/observation.py
+++ b/resolve/data/observation.py
@@ -9,7 +9,7 @@ import nifty8 as ift
 from .antenna_positions import AntennaPositions
 from ..constants import SPEEDOFLIGHT
 from .direction import Direction, Directions
-from ..mpi import onlymaster
+from ..mpi import onlymaster, master
 from .polarization import Polarization
 from ..util import compare_attributes, my_assert, my_assert_isinstance, my_asserteq
 
@@ -354,6 +354,29 @@ class Observation(BaseObservation):
             direction,
         )
 
+    @staticmethod
+    def split_data_file(self, data_path, ntask, target_folder, base_name, nwork):
+        from os import makedirs
+        makedirs(target_folder, exist_ok=True)
+
+        for rank in range(ntask):
+            lo, hi = ift.utilities.shareRange(nwork, ntask, rank)
+            obs = Observation.load(data_path, (lo, hi))
+            obs.save(f"{target_folder}/{base_name}_{rank}.npz")
+
+    @staticmethod
+    def mpi_load(self, data_folder, base_name, full_data_set, nwork, comm=None):
+        if master:
+            from os.path import isdir
+            if not isdir(data_folder):
+                Observation.split_data_file(full_data_set, comm.Get_size, data_folder, base_name, nwork)
+            if comm is None:
+                return Observation.load(full_data_set)
+
+        comm.Barrier()
+        return Observation.load(f"{data_folder}/{base_name}_{comm.Get_rank()}.npz")
+
+
     def flags_to_nan(self):
         if self.fraction_useful == 1.:
             return self
-- 
GitLab