From 886cae82b51a56afb64edb9abc181c1e9f673d1e Mon Sep 17 00:00:00 2001
From: Philipp Schubert <p.schubert@stud.uni-heidelberg.de>
Date: Mon, 23 Aug 2021 13:09:55 +0200
Subject: [PATCH] minor changes

---
 scripts/point_party/semseg_gt.py         |  22 ++--
 syconn/cnn/TrainData.py                  |   9 ++-
 syconn/cnn/cnn_semseg_lcp.py             |  49 +++++----
 syconn/extraction/cs_processing_steps.py | 128 ++++++++++++-----------
 4 files changed, 114 insertions(+), 94 deletions(-)

diff --git a/scripts/point_party/semseg_gt.py b/scripts/point_party/semseg_gt.py
index 3d900905..e49da2ec 100755
--- a/scripts/point_party/semseg_gt.py
+++ b/scripts/point_party/semseg_gt.py
@@ -281,17 +281,7 @@ j0251: 'dendrite': 0, 'axon': 1, 'soma': 2, 'bouton': 3, 'terminal': 4, 'neck':
        'nr': 7, 'in': 8, 'p': 9, 'st': 10, 'ignore': 11, 'merger': 12, 'pure_dendrite': 13,
        'pure_axon': 14}
 """
-# j0251 mappings
-label_mappings = dict(fine=[(7, 5), (8, 5), (9, 5), (10, 6)],
-                      # map nr, in, p, neck to "neck" (1) and st, head to "head" (2).
-                      dnh=[(7, 1), (8, 1), (9, 1), (5, 1), (10, 2), (6, 2)],
-                      # map bouton to 1 and terminal to 2
-                      abt=[(3, 1), (4, 2)],
-                      # map all dendritic compartments to dendrite (0) and all axonic to axon (1)
-                      ads=[(7, 0), (8, 0), (9, 0), (5, 0), (10, 0), (6, 0), (3, 1), (4, 1)],
-                      )
-
-# j0251 ignore labels
+# j0251 ignore labels - applied before the label_mappings defined below!
 label_remove = dict(
     # ignore "ignore", merger, pure dendrite and pure axon (TODO: what are those?!)
     fine=[11, 12, 13, 14, 15, -1],
@@ -303,6 +293,16 @@ label_remove = dict(
     ads=[11, 12, 13, 14, 15, -1],
 )
 
+# j0251 mappings
+label_mappings = dict(fine=[(7, 5), (8, 5), (9, 5), (10, 6)],  # nr, in, p to "neck" (5); st (10; stubby) to "head" (6)
+                      # map nr, in, p, neck to "neck" (1) and head, st (10; stubby) to "head" (2); dendrite stays 0
+                      dnh=[(7, 1), (8, 1), (9, 1), (5, 1), (10, 2), (6, 2)],
+                      # map axon to 0, bouton to 1 and terminal to 2
+                      abt=[(1, 0), (3, 1), (4, 2)],
+                      # map all dendritic compartments to dendrite (0) and all axonic to axon (1)
+                      ads=[(7, 0), (8, 0), (9, 0), (5, 0), (10, 0), (6, 0), (3, 1), (4, 1)],
+                      )
+
 class_nums = dict(fine=7, dnh=3, abt=3, ads=3)
 target_names = dict(fine=['dendrite', 'axon', 'soma', 'bouton', 'terminal', 'neck', 'head'],
                     dnh=['dendrite', 'neck', 'head'],
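
Note on the two blocks above: label_remove is applied before label_mappings, so unusable ids are already taken out of the picture before any pair is remapped. A minimal sketch of that order, assuming integer numpy label arrays; the helper apply_gt_remapping is hypothetical and masks removed ids with an ignore label, whereas the actual pipeline may instead drop those points:

    import numpy as np

    def apply_gt_remapping(labels, remove, mappings, ignore_label):
        labels = labels.copy()
        # step 1: collapse all ids listed in 'label_remove' onto the ignore label
        labels[np.isin(labels, remove)] = ignore_label
        # step 2: apply 'label_mappings' against a snapshot so pairs like
        # (5, 1) and (10, 2) act on the original labels, not on each other
        snapshot = labels.copy()
        for src, dst in mappings:
            labels[snapshot == src] = dst
        return labels

    # 'dnh' example: nr/in/p/neck -> neck (1), st/head -> head (2)
    lbs = np.array([0, 5, 7, 10, 11])
    print(apply_gt_remapping(lbs, remove=[11, 12, 13, 14, 15, -1],
                             mappings=[(7, 1), (8, 1), (9, 1), (5, 1), (10, 2), (6, 2)],
                             ignore_label=3))  # -> [0 1 1 2 3]
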
diff --git a/syconn/cnn/TrainData.py b/syconn/cnn/TrainData.py
index e4a432d3..04fa98f0 100755
--- a/syconn/cnn/TrainData.py
+++ b/syconn/cnn/TrainData.py
@@ -449,7 +449,8 @@ if elektronn3_avail:
 
     class CloudDataSemseg(Dataset):
         def __init__(self, source_dir=None, npoints=12000, transform: Callable = Identity(),
-                     train=True, batch_size=2, use_subcell=True, ctx_size=8000, mask_borders_with_id=None):
+                     train=True, batch_size=2, use_subcell=True, ctx_size=8000, mask_borders_with_id=None,
+                     remap_dict: Optional[dict] = None):
             if source_dir is None:
                 # source_dir = '/wholebrain/songbird/j0126/GT/compartment_gt_2020/2020_05//hc_out_2020_08/'
                 # ssv_ids_proof = [34811392, 26501121, 2854913, 37558272, 33581058, 491527, 16096256, 10919937, 46319619,
@@ -477,10 +478,16 @@
             self._batch_size = batch_size
             self.transform = transform
             self.mask_borders_with_id = mask_borders_with_id
+            self.remap_dict = remap_dict
 
         def __getitem__(self, item):
             item = np.random.randint(0, len(self.fnames))
             sample_pts, sample_feats, out_labels = self.load_sample(item)
+            if self.remap_dict is not None:
+                # match against a snapshot so overlapping remap pairs do not chain
+                orig_labels = out_labels.copy()
+                for k, v in self.remap_dict.items():
+                    out_labels[orig_labels == k] = v
             pts = torch.from_numpy(sample_pts).float()
             feats = torch.from_numpy(sample_feats).float()
             lbs = torch.from_numpy(out_labels).long()
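
The snapshot in the remap loop above matters because the remap dicts defined in cnn_semseg_lcp.py contain overlapping sources and targets, and a naive in-place loop chains such pairs. A minimal, self-contained illustration (values borrowed from the 'abt' dict, where 3 is the ignore label):

    import numpy as np

    remap = {0: 3, 3: 1}  # dendrite -> ignore (3), bouton (3) -> class 1
    labels = np.array([0, 3, 1])

    # naive in-place remap: 0 -> 3 runs first, so the later 3 -> 1 step also
    # rewrites the voxels that were just masked, yielding [1, 1, 1]
    naive = labels.copy()
    for k, v in remap.items():
        naive[naive == k] = v

    # matching against a snapshot keeps the steps independent -> [3, 1, 1]
    snapshot = labels.copy()
    remapped = labels.copy()
    for k, v in remap.items():
        remapped[snapshot == k] = v
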
diff --git a/syconn/cnn/cnn_semseg_lcp.py b/syconn/cnn/cnn_semseg_lcp.py
index 020c3a5b..a6fdadf3 100755
--- a/syconn/cnn/cnn_semseg_lcp.py
+++ b/syconn/cnn/cnn_semseg_lcp.py
@@ -69,16 +69,27 @@ normalize_pts = True
 eval_nr = random_seed  # number of repetition
 cellshape_only = False
 use_syntype = False
-# 'dendrite': 0, 'axon': 1, 'soma': 2, 'bouton': 3, 'terminal': 4, 'neck': 5, 'head': 6
-num_classes = 7
+# ads: axon dendrite soma
+# abt: axon bouton terminal
+# dnh: dendrite neck head
+# fine: 'dendrite': 0, 'axon': 1, 'soma': 2, 'bouton': 3, 'terminal': 4, 'neck': 5, 'head': 6
+gt_type = 'dnh'
+num_classes = {'ads': 3, 'abt': 3, 'dnh': 3, 'fine': 7}
+ignore_l = num_classes[gt_type]  # the class count also serves as the ignore label
+remap_dicts = {'ads': {3: 1, 4: 1, 5: 0, 6: 0},  # bouton/terminal -> axon, neck/head -> dendrite
+               'abt': {0: ignore_l, 2: ignore_l, 5: ignore_l, 6: ignore_l, 1: 0, 3: 1, 4: 2},
+               'dnh': {1: ignore_l, 2: ignore_l, 3: ignore_l, 4: ignore_l, 5: 1, 6: 2},
+               'fine': {}}
+weights = dict(ads=[1, 1, 1], abt=[1, 2, 2], dnh=[1, 2, 2], fine=[1, 1, 1, 2, 8, 4, 8])
+
 use_subcell = True
 if cellshape_only:
     use_subcell = False
     use_syntype = False
 
 if name is None:
-    name = f'semseg_pts_nb{npoints}_ctx{ctx}_nclass' \
-           f'{num_classes}_ptconv_noScale_BN_strongerWeighted_noKernelSep_noBatchAvg'
+    name = f'semseg_pts_nb{npoints}_ctx{ctx}_{gt_type}_nclass' \
+           f'{num_classes[gt_type]}_ptconv_BN_strongerWeighted_noKernelSep'
     if not normalize_pts:
         name += '_NonormPts'
     if cellshape_only:
@@ -99,7 +109,7 @@ print(f'Running on device: {device}')
 
 # set paths
 if save_root is None:
-    save_root = '/wholebrain/scratch/pschuber/e3_trainings_ptconv_semseg_j0251_July2021/'
+    save_root = '/wholebrain/scratch/pschuber/e3_trainings_ptconv_semseg_j0251_August2021/'
 save_root = os.path.expanduser(save_root)
 
 # CREATE NETWORK AND PREPARE DATA SET
@@ -108,7 +118,7 @@ save_root = os.path.expanduser(save_root)
 search = 'SearchQuantized'
 conv = dict(layer='ConvPoint', kernel_separation=False, normalize_pts=normalize_pts)
 act = nn.ReLU
-model = ConvAdaptSeg(input_channels, num_classes, get_conv(conv), get_search(search), kernel_num=64,
+model = ConvAdaptSeg(input_channels, num_classes[gt_type], get_conv(conv), get_search(search), kernel_num=64,
                      architecture=None, activation=act, norm='bn')
 
 name += f'_eval{eval_nr}'
@@ -134,22 +144,22 @@ elif args.jit == 'train':
 # Transformations to be applied to samples before feeding them to the network
 train_transform = clouds.Compose([clouds.RandomVariation((-20, 20), distr='normal'),  # in nm
                                   clouds.Center(),
-                                  # clouds.Normalization(scale_norm),
+                                  clouds.Normalization(scale_norm),
                                   clouds.RandomRotate(apply_flip=True),
                                   clouds.ElasticTransform(res=(40, 40, 40), sigma=6),
                                   clouds.RandomScale(distr_scale=0.05, distr='uniform')])
 valid_transform = clouds.Compose([clouds.Center(),
-                                  # clouds.Normalization(scale_norm)
+                                  clouds.Normalization(scale_norm)
                                   ])
 
 # mask border points with 'num_classes' and set their weight to 0
 source_dir = '/wholebrain/songbird/j0251/groundtruth/compartment_gt/2021_06_30_more_samples/hc_out_2021_06/'
 train_ds = CloudDataSemseg(npoints=npoints, transform=train_transform, use_subcell=use_subcell,
-                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=num_classes,
-                           source_dir=source_dir)
+                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=ignore_l,
+                           source_dir=source_dir, remap_dict=remap_dicts[gt_type])
 valid_ds = CloudDataSemseg(npoints=npoints, transform=valid_transform, train=False, use_subcell=use_subcell,
-                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=num_classes,
-                           source_dir=source_dir)
+                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=ignore_l,
+                           source_dir=source_dir, remap_dict=remap_dicts[gt_type])
 
 # PREPARE AND START TRAINING #
 
@@ -177,8 +187,8 @@ lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
 #     gamma=0.99994,
 # )
 # set weight of the masking label at context borders to 0
-class_weights = torch.tensor([1, 1, 1, 2, 8, 4, 8], dtype=torch.float32, device=device)
-criterion = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=num_classes).to(device)
+class_weights = torch.tensor(weights[gt_type], dtype=torch.float32, device=device)
+criterion = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_l).to(device)
 
 valid_metrics = {  # mean metrics
     'val_accuracy_mean': metrics.Accuracy(),
@@ -187,11 +197,11 @@ valid_metrics = {  # mean metrics
     'val_DSC_mean': metrics.DSC(),
     'val_IoU_mean': metrics.IoU(),
 }
-if num_classes > 2:
+if num_classes[gt_type] > 2:
     # Add separate per-class accuracy metrics only if there are more than 2 classes
     valid_metrics.update({
         f'val_IoU_c{i}': metrics.Accuracy(i)
-        for i in range(num_classes)
+        for i in range(num_classes[gt_type])
     })
 
 # Create trainer
@@ -211,11 +221,11 @@ trainer = Trainer3d(
     enable_save_trace=enable_save_trace,
     exp_name=name,
     schedulers={"lr": lr_sched},
-    num_classes=num_classes,
+    num_classes=num_classes[gt_type],
     # example_input=example_input,
     dataloader_kwargs=dict(collate_fn=lambda x: x[0]),
-    nbatch_avg=1,
-    tqdm_kwargs=dict(disable=True),
+    nbatch_avg=4,
+    tqdm_kwargs=dict(disable=False),
     lcp_flag=True
 )
 
diff --git a/syconn/extraction/cs_processing_steps.py b/syconn/extraction/cs_processing_steps.py
index c40829e5..483f24cf 100755
--- a/syconn/extraction/cs_processing_steps.py
+++ b/syconn/extraction/cs_processing_steps.py
@@ -323,7 +323,7 @@ def combine_and_split_syn(wd, cs_gap_nm=300, ssd_version=None, syn_version=None,
     log_extraction.debug(f'Filtering relevant synapses done.')
     storage_location_ids = get_unique_subfold_ixs(n_folders_fs)
 
-    n_used_paths = min(global_params.config.ncore_total * 4, len(storage_location_ids),
+    n_used_paths = min(global_params.config.ncore_total * 8, len(storage_location_ids),
                        len(rel_ssv_with_syn_ids))
     voxel_rel_paths = chunkify([subfold_from_ix(ix, n_folders_fs) for ix in storage_location_ids],
                                n_used_paths)
@@ -422,68 +422,70 @@ def _combine_and_split_syn_thread(args):
 
         voxel_list = np.concatenate(voxel_list)
         for this_cc in ccs:
-            this_cc_mask = np.array(list(this_cc))
-            # retrieve the index of the syn objects selected for this CC
-            this_syn_ixs, this_syn_ids_cnt = np.unique(synix_list[this_cc_mask],
-                                                       return_counts=True)
-            # the weight is important
-            this_agg_syn_weights = this_syn_ids_cnt / np.sum(this_syn_ids_cnt)
-            if np.sum(this_syn_ids_cnt) < cell_obj_cnf['min_obj_vx']['syn_ssv']:
-                continue
-            this_attr = syn_attr_list[this_syn_ixs]
-            this_vx = voxel_list[this_cc_mask]
-            syn_ssv = sd_syn_ssv.get_segmentation_object(syn_ssv_id)
-            if (os.path.abspath(syn_ssv.attr_dict_path)
-                    != os.path.abspath(base_dir + "/attr_dict.pkl")):
-                raise ValueError(f'Path mis-match!')
-            synssv_attr_dc = dict(neuron_partners=ssv_ids)
-            voxel_dc[syn_ssv_id] = this_vx
-            synssv_attr_dc["rep_coord"] = this_vx[len(this_vx) // 2]  # any rep coord
-            synssv_attr_dc["bounding_box"] = np.array([np.min(this_vx, axis=0), np.max(this_vx, axis=0)])
-            synssv_attr_dc["size"] = len(this_vx)
-            # calc_contact_syn_mesh returns a list with a single mesh (for syn_ssv)
-            if mesh_min_obj_vx < synssv_attr_dc["size"]:
-                syn_ssv._mesh = calc_contact_syn_mesh(syn_ssv, voxel_dc=voxel_dc, **syn_meshing_kws)[0]
-                mesh_dc[syn_ssv.id] = syn_ssv.mesh
-                synssv_attr_dc["mesh_bb"] = syn_ssv.mesh_bb
-                synssv_attr_dc["mesh_area"] = syn_ssv.mesh_area
-            else:
-                zero_mesh = [np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32),
-                             np.zeros((0,), dtype=np.float32)]
-                mesh_dc[syn_ssv.id] = zero_mesh
-                synssv_attr_dc["mesh_bb"] = synssv_attr_dc["bounding_box"] * scaling
-                synssv_attr_dc["mesh_area"] = 0
-            # aggregate syn properties
-            syn_props_agg = {}
-            # cs_id is the same as syn_id ('syn' are just a subset of 'cs')
-            for dc in this_attr:
-                for k in ['id_cs_ratio', 'cs_id', 'sym_prop', 'asym_prop']:
-                    syn_props_agg.setdefault(k, []).append(dc[k])
-            # rename and delete old entry
-            syn_props_agg['cs_ids'] = syn_props_agg['cs_id']
-            del syn_props_agg['cs_id']
-
-            # use the fraction of 'syn' voxels used for this connected component, i.e. 'this_agg_syn_weights', as weight
-            # agglomerate the syn-to-cs ratio as a weighted sum
-            syn_props_agg['id_cs_ratio'] = np.sum(this_agg_syn_weights * np.array(syn_props_agg['id_cs_ratio']))
-
-            # 'syn_ssv' synapse type as weighted sum of the 'syn' fragment types
-            sym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['sym_prop']))
-            asym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['asym_prop']))
-            syn_props_agg['sym_prop'] = sym_prop
-            syn_props_agg['asym_prop'] = asym_prop
-
-            if sym_prop + asym_prop == 0:
-                sym_ratio = -1
-            else:
-                sym_ratio = sym_prop / float(asym_prop + sym_prop)
-            syn_props_agg["syn_type_sym_ratio"] = sym_ratio
-            syn_sign = -1 if sym_ratio > cell_obj_cnf['sym_thresh'] else 1
-            syn_props_agg["syn_sign"] = syn_sign
-
-            # add syn_ssv dict to AttributeStorage
-            synssv_attr_dc.update(syn_props_agg)
-            attr_dc[syn_ssv_id] = synssv_attr_dc
+            # do not process this synapse again if the job has been restarted
+            if syn_ssv_id not in attr_dc:
+                this_cc_mask = np.array(list(this_cc))
+                # retrieve the index of the syn objects selected for this CC
+                this_syn_ixs, this_syn_ids_cnt = np.unique(synix_list[this_cc_mask],
+                                                           return_counts=True)
+                # voxel-count fraction of each 'syn' fragment, used as aggregation weight below
+                this_agg_syn_weights = this_syn_ids_cnt / np.sum(this_syn_ids_cnt)
+                if np.sum(this_syn_ids_cnt) < cell_obj_cnf['min_obj_vx']['syn_ssv']:
+                    continue
+                this_attr = syn_attr_list[this_syn_ixs]
+                this_vx = voxel_list[this_cc_mask]
+                syn_ssv = sd_syn_ssv.get_segmentation_object(syn_ssv_id)
+                if (os.path.abspath(syn_ssv.attr_dict_path)
+                        != os.path.abspath(base_dir + "/attr_dict.pkl")):
+                    raise ValueError('Path mismatch!')
+                synssv_attr_dc = dict(neuron_partners=ssv_ids)
+                voxel_dc[syn_ssv_id] = this_vx
+                synssv_attr_dc["rep_coord"] = this_vx[len(this_vx) // 2]  # any voxel works as representative coordinate
+                synssv_attr_dc["bounding_box"] = np.array([np.min(this_vx, axis=0), np.max(this_vx, axis=0)])
+                synssv_attr_dc["size"] = len(this_vx)
+                # calc_contact_syn_mesh returns a list with a single mesh (for syn_ssv)
+                if mesh_min_obj_vx < synssv_attr_dc["size"]:
+                    syn_ssv._mesh = calc_contact_syn_mesh(syn_ssv, voxel_dc=voxel_dc, **syn_meshing_kws)[0]
+                    mesh_dc[syn_ssv.id] = syn_ssv.mesh
+                    synssv_attr_dc["mesh_bb"] = syn_ssv.mesh_bb
+                    synssv_attr_dc["mesh_area"] = syn_ssv.mesh_area
+                else:
+                    zero_mesh = [np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32),
+                                 np.zeros((0,), dtype=np.float32)]
+                    mesh_dc[syn_ssv.id] = zero_mesh
+                    synssv_attr_dc["mesh_bb"] = synssv_attr_dc["bounding_box"] * scaling
+                    synssv_attr_dc["mesh_area"] = 0
+                # aggregate syn properties
+                syn_props_agg = {}
+                # cs_id is the same as syn_id ('syn' are just a subset of 'cs')
+                for dc in this_attr:
+                    for k in ['id_cs_ratio', 'cs_id', 'sym_prop', 'asym_prop']:
+                        syn_props_agg.setdefault(k, []).append(dc[k])
+                # rename and delete old entry
+                syn_props_agg['cs_ids'] = syn_props_agg['cs_id']
+                del syn_props_agg['cs_id']
+
+                # use the fraction of 'syn' voxels contributing to this connected component ('this_agg_syn_weights') as weight
+                # agglomerate the syn-to-cs ratio as a weighted sum
+                syn_props_agg['id_cs_ratio'] = np.sum(this_agg_syn_weights * np.array(syn_props_agg['id_cs_ratio']))
+
+                # 'syn_ssv' synapse type as weighted sum of the 'syn' fragment types
+                sym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['sym_prop']))
+                asym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['asym_prop']))
+                syn_props_agg['sym_prop'] = sym_prop
+                syn_props_agg['asym_prop'] = asym_prop
+
+                if sym_prop + asym_prop == 0:
+                    sym_ratio = -1
+                else:
+                    sym_ratio = sym_prop / float(asym_prop + sym_prop)
+                syn_props_agg["syn_type_sym_ratio"] = sym_ratio
+                syn_sign = -1 if sym_ratio > cell_obj_cnf['sym_thresh'] else 1
+                syn_props_agg["syn_sign"] = syn_sign
+
+                # add syn_ssv dict to AttributeStorage
+                synssv_attr_dc.update(syn_props_agg)
+                attr_dc[syn_ssv_id] = synssv_attr_dc
             if use_new_subfold:
                 syn_ssv_id += np.uint(1)
                 if syn_ssv_id - base_id >= div_base:
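
The aggregation in the block above, reduced to its core: per-fragment 'syn' properties are combined as a voxel-count weighted sum, and the synapse sign follows from the symmetric ratio. Values are illustrative; sym_thresh comes from the project config:

    import numpy as np

    cnt = np.array([60, 40])                 # voxels per 'syn' fragment in this CC
    w = cnt / cnt.sum()                      # this_agg_syn_weights
    sym = np.sum(w * np.array([0.9, 0.2]))   # weighted sym_prop  -> 0.62
    asym = np.sum(w * np.array([0.1, 0.8]))  # weighted asym_prop -> 0.38

    sym_ratio = -1 if sym + asym == 0 else sym / (sym + asym)
    sym_thresh = 0.5                         # assumed value for illustration
    syn_sign = -1 if sym_ratio > sym_thresh else 1  # -1: symmetric, +1: asymmetric
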
-- 
GitLab