From f48ace31286459b6b67fde1e137553e7f6b13c5f Mon Sep 17 00:00:00 2001
From: Philipp Schubert <p.schubert@stud.uni-heidelberg.de>
Date: Wed, 25 Nov 2020 14:07:42 +0100
Subject: [PATCH] improve memory usage in
 `_aggregate_segmentation_object_mappings_thread`

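Instead of enabling the SegmentationDataset property cache (one numpy
array per attribute, covering every cell supervoxel) in each worker,
only the `mapping_*` attributes of the supervoxels belonging to the
worker's SSV block are collected. The new helper `prepare_so_attr_cache`
builds plain per-attribute lookup dicts; the intermediate numpy caches,
the ID-to-index lookup and the SSD mapping dict are released as soon as
they are no longer needed.

Intended use of the helper, as a minimal sketch (here `ssd` is an
existing SuperSegmentationDataset, `sv_ids` are the worker's cell
supervoxel IDs, and the organelle type 'mi' only stands in for any
configured type):

    from syconn.reps import segmentation
    from syconn.reps.segmentation_helper import prepare_so_attr_cache

    # cell supervoxel dataset of the current working directory
    sd_cell = segmentation.SegmentationDataset('sv', config=ssd.config)
    # per-attribute lookup restricted to the supervoxels in `sv_ids`
    attr_cache = prepare_so_attr_cache(
        sd_cell, sv_ids, ['mapping_mi_ids', 'mapping_mi_ratios'])
    # mapped organelle IDs and overlap ratios of the first supervoxel
    mi_ids = attr_cache['mapping_mi_ids'][sv_ids[0]]
    mi_ratios = attr_cache['mapping_mi_ratios'][sv_ids[0]]

Further changes: `run_create_neuron_ssd` calls
`aggregate_segmentation_object_mappings` with nb_cpus=2, the ID-to-index
lookup of SegmentationDataset is exposed as the lazily built `soid2ix`
property, the j0251 test cube size is increased (factor 7.5 instead of
6, ~1.81 TVx), and scripts/misc/timings.py now also creates the
volume-vs-time plots, globs run directories from `base_dir` and uses
linear instead of log-scaled y-axes.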
---
 scripts/dataset_specific/start_j0251_gce.py |  5 ++--
 scripts/misc/timings.py                     | 26 +++++++++---------
 syconn/exec/exec_init.py                    |  2 +-
 syconn/proc/ssd_proc.py                     | 29 +++++++++++---------
 syconn/reps/segmentation.py                 | 12 ++++++---
 syconn/reps/segmentation_helper.py          | 30 ++++++++++++++++++---
 6 files changed, 69 insertions(+), 35 deletions(-)

diff --git a/scripts/dataset_specific/start_j0251_gce.py b/scripts/dataset_specific/start_j0251_gce.py
index 067567a2..839941b4 100755
--- a/scripts/dataset_specific/start_j0251_gce.py
+++ b/scripts/dataset_specific/start_j0251_gce.py
@@ -29,7 +29,8 @@ if __name__ == '__main__':
     scale = np.array([10, 10, 25])
     number_of_nodes = 24
     node_states = nodestates_slurm()
-    node_state = next(iter(node_states.values()))
+    node_state = next(
+        iter(node_states.values()))
     exclude_nodes = []
     for nk in list(node_states.keys())[number_of_nodes:]:
         exclude_nodes.append(nk)
@@ -41,7 +42,7 @@ if __name__ == '__main__':
     ngpus_per_node = 2  # node_state currently does not contain the number of gpus for 'gres' resource
     shape_j0251 = np.array([27119, 27350, 15494])
     # 10.5* for 4.9, *9 for 3.13, *7.5 for 1.81, *6 for 0.927, *4.5 for 0.391, *3 for 0.115 TVx
-    cube_size = (np.array([2048, 2048, 1024]) * 6).astype(np.int)
+    cube_size = (np.array([2048, 2048, 1024]) * 7.5).astype(np.int)
     # all for 10 TVx
     cube_offset = ((shape_j0251 - cube_size) // 2).astype(np.int)
     cube_of_interest_bb = np.array([cube_offset, cube_offset + cube_size], dtype=np.int)
diff --git a/scripts/misc/timings.py b/scripts/misc/timings.py
index 9196e3c1..7a42584b 100755
--- a/scripts/misc/timings.py
+++ b/scripts/misc/timings.py
@@ -20,10 +20,11 @@ palette_ident = 'colorblind'
 
 def get_speed_plots(base_dir):
     sns.set_style("ticks", {"xtick.major.size": 20, "ytick.major.size": 20})
-    wds = glob.glob('/mnt/example_runs/j0251_*')
+    wds = glob.glob(f'{base_dir}/j0251_*')
+    assert len(wds) > 0
     base_dir = base_dir + '/timings/'
     log = initialize_logging(f'speed_plots', log_dir=base_dir)
-    log.info(f'Creating timing plots in base directory "{base_dir}".')
+    log.info(f'Creating speed plots in base directory "{base_dir}".')
     os.makedirs(base_dir, exist_ok=True)
     res_dc = {'time': [], 'step': [], 'datasize[mm3]': [], 'datasize[GVx]': [],
               'speed[mm3]': [], 'speed[GVx]': []}
@@ -53,7 +54,6 @@ def get_speed_plots(base_dir):
                 vol_nvox = ft.dataset_nvoxels['neuron']
             res_dc['speed[mm3]'].append(vol_mm3 / dt)
             res_dc['speed[GVx]'].append(vol_nvox / dt)
-    assert len(wds) > 0
     palette = sns.color_palette(n_colors=len(np.unique(res_dc['step'])), palette=palette_ident)
     palette = {k: v for k, v in zip(np.unique(res_dc['step']), palette)}
     df = pd.DataFrame(data=res_dc)
@@ -199,13 +199,13 @@ def get_timing_plots(base_dir):
             x_fit = np.linspace(np.min(x), np.max(x), 1000)
             y_fit = res.params[1] * x_fit + res.params[0]
             # plt.plot(x_fit, y_fit, color=palette[step])
-        plt.yscale('log')
+        # plt.yscale('log')
         plt.xticks(np.arange(8, 28, step=4))
         axes.spines['right'].set_visible(False)
         axes.spines['top'].set_visible(False)
         axes.legend(*axes.get_legend_handles_labels(), bbox_to_anchor=(1.05, 1),
                     loc='upper left', borderaxespad=0.)
-        axes.set_ylabel('time [h] (log scale)')
+        axes.set_ylabel('time [h]')
         axes.set_xlabel('no. compute nodes [1]')
         plt.subplots_adjust(right=0.75)
         plt.savefig(base_dir + '/timing_allsteps_regplot_diff_nodes.png', dpi=600)
@@ -237,13 +237,13 @@ def get_timing_plots(base_dir):
             x_fit = np.linspace(np.min(x), np.max(x), 1000)
             y_fit = res.params[1] * x_fit + res.params[0]
             # plt.plot(x_fit, y_fit, color=palette[step])
-        plt.yscale('log')
+        # plt.yscale('log')
         plt.xticks(np.arange(8, 28, step=4))
         axes.spines['right'].set_visible(False)
         axes.spines['top'].set_visible(False)
         axes.legend(*axes.get_legend_handles_labels(), bbox_to_anchor=(1.05, 1),
                     loc='upper left', borderaxespad=0.)
-        axes.set_ylabel('time [h] (log scale)')
+        axes.set_ylabel('time [h]')
         axes.set_xlabel('no. compute nodes [1]')
         plt.subplots_adjust(right=0.75)
         plt.savefig(base_dir + '/time_allsteps_regplot_diff_nodes_wo_views.png', dpi=600)
@@ -325,12 +325,12 @@ def get_timing_plots(base_dir):
             x_fit = np.linspace(np.min(x), np.max(x), 1000)
             y_fit = res.params[1] * x_fit + res.params[0]
             plt.plot(x_fit, y_fit, color=palette[step])
-        plt.yscale('log')
+        # plt.yscale('log')
         axes.spines['right'].set_visible(False)
         axes.spines['top'].set_visible(False)
         axes.legend(*axes.get_legend_handles_labels(), bbox_to_anchor=(1.05, 1),
                     loc='upper left', borderaxespad=0.)
-        axes.set_ylabel('time [h] (log scale)')
+        axes.set_ylabel('time [h]')
         axes.set_xlabel('size [GVx]')
         plt.subplots_adjust(right=0.75)
         plt.savefig(base_dir + '/timing_allsteps_regplot.png', dpi=600)
@@ -383,12 +383,12 @@ def get_timing_plots(base_dir):
             x_fit = np.linspace(np.min(x), np.max(x), 1000)
             y_fit = res.params[1] * x_fit + res.params[0]
             plt.plot(x_fit, y_fit, color=palette[step])
-        plt.yscale('log')
+        # plt.yscale('log')
         axes.spines['right'].set_visible(False)
         axes.spines['top'].set_visible(False)
         axes.legend(*axes.get_legend_handles_labels(), bbox_to_anchor=(1.05, 1),
                     loc='upper left', borderaxespad=0.)
-        axes.set_ylabel('time [h] (log scale)')
+        axes.set_ylabel('time [h]')
         axes.set_xlabel('size [GVx]')
         plt.subplots_adjust(right=0.75)
         plt.savefig(base_dir + '/timing_allsteps_regplot_wo_views.png', dpi=600)
@@ -397,5 +397,5 @@ def get_timing_plots(base_dir):
 
 if __name__ == '__main__':
     get_timing_plots('/mnt/example_runs/nodes_vs_time/')
-    # get_timing_plots('/mnt/example_runs/vol_vs_time/')
-    # get_speed_plots('/mnt/example_runs/vol_vs_time')
+    get_timing_plots('/mnt/example_runs/vol_vs_time/')
+    get_speed_plots('/mnt/example_runs/vol_vs_time/')
diff --git a/syconn/exec/exec_init.py b/syconn/exec/exec_init.py
index 0ce2f8cd..f21ad16c 100755
--- a/syconn/exec/exec_init.py
+++ b/syconn/exec/exec_init.py
@@ -109,7 +109,7 @@ def run_create_neuron_ssd(apply_ssv_size_threshold: Optional[bool] = None):
 
     log.info('Finished SSD initialization. Starting cellular organelle mapping.')
     # map cellular organelles to SSVs
-    ssd_proc.aggregate_segmentation_object_mappings(ssd, global_params.config['existing_cell_organelles'])
+    ssd_proc.aggregate_segmentation_object_mappings(ssd, global_params.config['existing_cell_organelles'], nb_cpus=2)
     ssd_proc.apply_mapping_decisions(ssd, global_params.config['existing_cell_organelles'])
     log.info('Finished mapping of cellular organelles to SSVs. Writing individual SSV graphs.')
 
diff --git a/syconn/proc/ssd_proc.py b/syconn/proc/ssd_proc.py
index a816acb5..ae088bee 100755
--- a/syconn/proc/ssd_proc.py
+++ b/syconn/proc/ssd_proc.py
@@ -11,6 +11,7 @@ from ..mp import batchjob_utils as qu
 from ..mp import mp_utils as sm
 from ..proc.meshes import mesh_creator_sso
 from ..reps import segmentation, super_segmentation
+from ..reps.segmentation_helper import prepare_so_attr_cache
 from ..reps.super_segmentation import SuperSegmentationObject, SuperSegmentationDataset
 
 from typing import Iterable, Tuple
@@ -43,14 +44,11 @@ def aggregate_segmentation_object_mappings(ssd: SuperSegmentationDataset, obj_ty
                      obj_types, ssd.type) for ssv_id_block in multi_params]
 
     if not qu.batchjob_enabled():
-        _ = sm.start_multiprocess_imap(
-            _aggregate_segmentation_object_mappings_thread,
-            multi_params, debug=False, nb_cpus=nb_cpus)
-
+        _ = sm.start_multiprocess_imap(_aggregate_segmentation_object_mappings_thread, multi_params,
+                                       debug=False, nb_cpus=nb_cpus)
     else:
-        _ = qu.batchjob_script(
-            multi_params, "aggregate_segmentation_object_mappings",
-            n_cores=nb_cpus, remove_jobfolder=True)
+        _ = qu.batchjob_script(multi_params, "aggregate_segmentation_object_mappings", n_cores=nb_cpus,
+                               remove_jobfolder=True)
 
 
 def _aggregate_segmentation_object_mappings_thread(args):
@@ -63,23 +61,28 @@ def _aggregate_segmentation_object_mappings_thread(args):
 
     ssd = super_segmentation.SuperSegmentationDataset(working_dir, version, ssd_type=ssd_type,
                                                       version_dict=version_dict)
+    svids = np.concatenate([ssd.mapping_dict[ssvid] for ssvid in ssv_obj_ids])
+    ssd._mapping_dict = None
     so_attr_of_interest = []
     # create cache for object attributes
     for obj_type in obj_types:
         so_attr_of_interest.extend([f"mapping_{obj_type}_ids", f"mapping_{obj_type}_ratios"])
-    sd_cell = segmentation.SegmentationDataset('sv', config=ssd.config, cache_properties=so_attr_of_interest)
+    attr_cache = prepare_so_attr_cache(segmentation.SegmentationDataset('sv', config=ssd.config), svids,
+                                       so_attr_of_interest)
 
     for ssv_id in ssv_obj_ids:
         ssv = ssd.get_super_segmentation_object(ssv_id)
+        ssv.load_attr_dict()
         mappings = dict((obj_type, Counter()) for obj_type in obj_types)
         for svid in ssv.sv_ids:
-            sv = sd_cell.get_segmentation_object(svid)
             for obj_type in obj_types:
-                if f"mapping_{obj_type}_ids" in sv.attr_dict:
-                    keys = sv.attr_dict[f"mapping_{obj_type}_ids"]
-                    values = sv.attr_dict[f"mapping_{obj_type}_ratios"]
+                try:
+                    keys = attr_cache[f"mapping_{obj_type}_ids"][svid]
+                    values = attr_cache[f"mapping_{obj_type}_ratios"][svid]
                     mappings[obj_type] += Counter(dict(zip(keys, values)))
-        ssv.load_attr_dict()
+                except KeyError:
+                    raise KeyError(f'Could not find attribute "mapping_{obj_type}_ids" for cell '
+                                   f'supervoxel {svid} during "_aggregate_segmentation_object_mappings_thread".')
         for obj_type in obj_types:
             if obj_type in mappings:
                 ssv.attr_dict[f"mapping_{obj_type}_ids"] = list(mappings[obj_type].keys())
diff --git a/syconn/reps/segmentation.py b/syconn/reps/segmentation.py
index 9a54882d..fb1abd68 100755
--- a/syconn/reps/segmentation.py
+++ b/syconn/reps/segmentation.py
@@ -1845,7 +1845,7 @@ class SegmentationDataset(SegmentationBase):
 
         so = SegmentationObject(**kwargs_def)
         for k, v in self._property_cache.items():
-            so.attr_dict[k] = v[self._soid2ix[obj_id]]
+            so.attr_dict[k] = v[self.soid2ix[obj_id]]
         return so
 
     def save_version_dict(self):
@@ -1863,6 +1863,12 @@ class SegmentationDataset(SegmentationBase):
         except Exception as e:
             raise FileNotFoundError('Version dictionary of SegmentationDataset not found. {}'.format(str(e)))
 
+    @property
+    def soid2ix(self):
+        if self._soid2ix is None:
+            self._soid2ix = {k: ix for ix, k in enumerate(self.ids)}
+        return self._soid2ix
+
     def enable_property_cache(self, property_keys: Iterable[str]):
         """
         Add properties to cache.
@@ -1877,8 +1883,8 @@ class SegmentationDataset(SegmentationBase):
                 property_keys.remove(k)
         if len(property_keys) == 0:
             return
-        if self._soid2ix is None:
-            self._soid2ix = {k: ix for ix, k in enumerate(self.ids)}
+        # build and cache the ID-to-index lookup
+        _ = self.soid2ix
         self._property_cache.update({k: self.load_numpy_data(k, allow_nonexisting=False) for k in property_keys})
 
     def get_volume(self, source: str = 'total') -> float:
diff --git a/syconn/reps/segmentation_helper.py b/syconn/reps/segmentation_helper.py
index 05688d27..7fdbd736 100755
--- a/syconn/reps/segmentation_helper.py
+++ b/syconn/reps/segmentation_helper.py
@@ -23,7 +23,7 @@ from ..mp.mp_utils import start_multiprocess_imap
 from ..proc.graphs import create_graph_from_coords
 
 if TYPE_CHECKING:
-    from ..reps.segmentation import SegmentationObject
+    from ..reps.segmentation import SegmentationObject, SegmentationDataset
 MeshType = Union[Tuple[np.ndarray, np.ndarray, np.ndarray], List[np.ndarray],
                  Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]
 
@@ -405,11 +405,11 @@ def sv_attr_exists(args):
     return missing_ids
 
 
-def find_missing_sv_attributes(sd, attr_key, n_cores=20):
+def find_missing_sv_attributes(sd: 'SegmentationDataset', attr_key: str, n_cores: int = 20):
     """
 
     Args:
-        sd: SegmentationDataset
+        sd: Dataset whose objects are checked for ``attr_key``.
         attr_key: str
         n_cores: int
 
@@ -521,6 +521,30 @@ def load_so_attr_bulk(sos: List['SegmentationObject'],
     return out
 
 
+def prepare_so_attr_cache(sd: 'SegmentationDataset', so_ids: np.ndarray, attr_keys: List[str]) -> Dict[str, dict]:
+    """
+
+    Args:
+        sd: SegmentationDataset.
+        so_ids: SegmentationObject IDs for which to collect the attributes.
+        attr_keys: Attribute keys to collect. Corresponding numpy arrays must exist.
+
+    Returns:
+        Dictionary with `attr_keys` as keys and per-ID attribute dictionaries as values for the IDs `so_ids`, e.g.
+        ``attr_cache[attr_keys[0]][so_ids[0]]`` returns the value of attribute ``attr_keys[0]`` for the first
+        SegmentationObject in `so_ids`.
+    """
+    attr_cache = {k: dict() for k in attr_keys}
+    soid2ix = {so_id: sd.soid2ix[so_id] for so_id in so_ids}
+    sd._soid2ix = None  # free memory
+    for attr in attr_keys:
+        np_cache = sd.load_numpy_data(attr, allow_nonexisting=False)
+        for so_id in so_ids:
+            attr_cache[attr][so_id] = np_cache[soid2ix[so_id]]
+        del np_cache
+    return attr_cache
+
+
 def load_so_voxels_bulk(sos: List['SegmentationObject'],
                         use_new_subfold: bool = True, cache_decomp=True):
     """
-- 
GitLab