From 886cae82b51a56afb64edb9abc181c1e9f673d1e Mon Sep 17 00:00:00 2001
From: Philipp Schubert <p.schubert@stud.uni-heidelberg.de>
Date: Mon, 23 Aug 2021 13:09:55 +0200
Subject: [PATCH] minor changes

---
 scripts/point_party/semseg_gt.py         |  22 ++--
 syconn/cnn/TrainData.py                  |   7 +-
 syconn/cnn/cnn_semseg_lcp.py             |  48 +++++----
 syconn/extraction/cs_processing_steps.py | 128 ++++++++++++-----
 4 files changed, 111 insertions(+), 94 deletions(-)

diff --git a/scripts/point_party/semseg_gt.py b/scripts/point_party/semseg_gt.py
index 3d900905..e49da2ec 100755
--- a/scripts/point_party/semseg_gt.py
+++ b/scripts/point_party/semseg_gt.py
@@ -281,17 +281,7 @@ j0251: 'dendrite': 0, 'axon': 1, 'soma': 2, 'bouton': 3, 'terminal': 4, 'neck':
 'nr': 7, 'in': 8, 'p': 9, 'st': 10, 'ignore': 11, 'merger': 12, 'pure_dendrite': 13,
 'pure_axon': 14}
 """
-# j0251 mappings
-label_mappings = dict(fine=[(7, 5), (8, 5), (9, 5), (10, 6)],
-                      # map nr, in, p, neck to "neck" (1) and st, head to "head" (2).
-                      dnh=[(7, 1), (8, 1), (9, 1), (5, 1), (10, 2), (6, 2)],
-                      # map bouton to 1 and terminal to 2
-                      abt=[(3, 1), (4, 2)],
-                      # map all dendritic compartments to dendrite (0) and all axonic to axon (1)
-                      ads=[(7, 0), (8, 0), (9, 0), (5, 0), (10, 0), (6, 0), (3, 1), (4, 1)],
-                      )
-
-# j0251 ignore labels
+# j0251 ignore labels - is applied before label_mappings from below!
 label_remove = dict(
     # ignore "ignore", merger, pure dendrite and pure axon (TODO: what are those?!)
     fine=[11, 12, 13, 14, 15, -1],
@@ -303,6 +293,16 @@ label_remove = dict(
     ads=[11, 12, 13, 14, 15, -1],
 )
 
+# j0251 mappings
+label_mappings = dict(fine=[(7, 5), (8, 5), (9, 5), (10, 6)],  # st (10; stubby) to "head"
+                      # map nr, in, p, neck to "neck" (1) and head, st (10; stubby) to "head" (2)., dendrite stays 0
+                      dnh=[(7, 1), (8, 1), (9, 1), (5, 1), (10, 2), (6, 2)],
+                      # map axon to 0, bouton to 1 and terminal to 2
+                      abt=[(1, 0), (3, 1), (4, 2)],
+                      # map all dendritic compartments to dendrite (0) and all axonic to axon (1)
+                      ads=[(7, 0), (8, 0), (9, 0), (5, 0), (10, 0), (6, 0), (3, 1), (4, 1)],
+                      )
+
 class_nums = dict(fine=7, dnh=3, abt=3, ads=3)
 target_names = dict(fine=['dendrite', 'axon', 'soma', 'bouton', 'terminal', 'neck', 'head'],
                     dnh=['dendrite', 'neck', 'head'],
diff --git a/syconn/cnn/TrainData.py b/syconn/cnn/TrainData.py
index e4a432d3..04fa98f0 100755
--- a/syconn/cnn/TrainData.py
+++ b/syconn/cnn/TrainData.py
@@ -449,7 +449,8 @@ if elektronn3_avail:
 
     class CloudDataSemseg(Dataset):
         def __init__(self, source_dir=None, npoints=12000, transform: Callable = Identity(),
-                     train=True, batch_size=2, use_subcell=True, ctx_size=8000, mask_borders_with_id=None):
+                     train=True, batch_size=2, use_subcell=True, ctx_size=8000, mask_borders_with_id=None,
+                     remap_dict: Optional[dict] = None):
             if source_dir is None:
                 # source_dir = '/wholebrain/songbird/j0126/GT/compartment_gt_2020/2020_05//hc_out_2020_08/'
                 # ssv_ids_proof = [34811392, 26501121, 2854913, 37558272, 33581058, 491527, 16096256, 10919937, 46319619,
@@ -477,10 +478,14 @@ if elektronn3_avail:
             self._batch_size = batch_size
             self.transform = transform
             self.mask_borders_with_id = mask_borders_with_id
+            self.remap_dict = remap_dict
 
         def __getitem__(self, item):
             item = np.random.randint(0, len(self.fnames))
             sample_pts, sample_feats, out_labels = self.load_sample(item)
+            if self.remap_dict is not None:
+                for k, v in self.remap_dict.items():
+                    out_labels[out_labels == k] = v
             pts = torch.from_numpy(sample_pts).float()
             feats = torch.from_numpy(sample_feats).float()
             lbs = torch.from_numpy(out_labels).long()
diff --git a/syconn/cnn/cnn_semseg_lcp.py b/syconn/cnn/cnn_semseg_lcp.py
index 020c3a5b..a6fdadf3 100755
--- a/syconn/cnn/cnn_semseg_lcp.py
+++ b/syconn/cnn/cnn_semseg_lcp.py
@@ -69,16 +69,26 @@ normalize_pts = True
 eval_nr = random_seed  # number of repetition
 cellshape_only = False
 use_syntype = False
-# 'dendrite': 0, 'axon': 1, 'soma': 2, 'bouton': 3, 'terminal': 4, 'neck': 5, 'head': 6
-num_classes = 7
+# ads: axon dendrite soma
+# abt: axon bouton terminal
+# fine: 'dendrite': 0, 'axon': 1, 'soma': 2, 'bouton': 3, 'terminal': 4, 'neck': 5, 'head': 6
+gt_type = 'dnh'
+num_classes = {'ads': 3, 'abt': 3, 'dnh': 3, 'fine': 7}
+ignore_l = num_classes[gt_type]  # num_classes is also used as ignore label
+remap_dicts = {'ads': {3: 1, 4: 1, 5: 2, 6: 2},
+               'abt': {0: ignore_l, 2: ignore_l, 5: ignore_l, 6: ignore_l, 1: 0, 3: 1, 4: 2},
+               'dnh': {1: ignore_l, 2: ignore_l, 3: ignore_l, 4: ignore_l, 5: 1, 6: 2},
+               'fine': {}}
+weights = dict(ads=[1, 1, 1], abt=[1, 2, 2], dnh=[1, 2, 2], fine=[1, 1, 1, 2, 8, 4, 8])
+
 use_subcell = True
 if cellshape_only:
     use_subcell = False
     use_syntype = False
 
 if name is None:
-    name = f'semseg_pts_nb{npoints}_ctx{ctx}_nclass' \
-           f'{num_classes}_ptconv_noScale_BN_strongerWeighted_noKernelSep_noBatchAvg'
+    name = f'semseg_pts_nb{npoints}_ctx{ctx}_{gt_type}_nclass' \
+           f'{num_classes[gt_type]}_ptconv_BN_strongerWeighted_noKernelSep'
     if not normalize_pts:
         name += '_NonormPts'
     if cellshape_only:
@@ -99,7 +109,7 @@ print(f'Running on device: {device}')
 
 # set paths
 if save_root is None:
-    save_root = '/wholebrain/scratch/pschuber/e3_trainings_ptconv_semseg_j0251_July2021/'
+    save_root = '/wholebrain/scratch/pschuber/e3_trainings_ptconv_semseg_j0251_August2021/'
 save_root = os.path.expanduser(save_root)
 
 # CREATE NETWORK AND PREPARE DATA SET #
@@ -108,7 +118,7 @@ save_root = os.path.expanduser(save_root)
 search = 'SearchQuantized'
 conv = dict(layer='ConvPoint', kernel_separation=False, normalize_pts=normalize_pts)
 act = nn.ReLU
-model = ConvAdaptSeg(input_channels, num_classes, get_conv(conv), get_search(search), kernel_num=64,
+model = ConvAdaptSeg(input_channels, num_classes[gt_type], get_conv(conv), get_search(search), kernel_num=64,
                      architecture=None, activation=act, norm='bn')
 
 name += f'_eval{eval_nr}'
@@ -134,22 +144,22 @@ elif args.jit == 'train':
 
 # Transformations to be applied to samples before feeding them to the network
 train_transform = clouds.Compose([clouds.RandomVariation((-20, 20), distr='normal'),  # in nm
                                   clouds.Center(),
-                                  # clouds.Normalization(scale_norm),
+                                  clouds.Normalization(scale_norm),
                                   clouds.RandomRotate(apply_flip=True),
                                   clouds.ElasticTransform(res=(40, 40, 40), sigma=6),
                                   clouds.RandomScale(distr_scale=0.05, distr='uniform')])
 valid_transform = clouds.Compose([clouds.Center(),
-                                  # clouds.Normalization(scale_norm)
+                                  clouds.Normalization(scale_norm)
                                   ])
 
 # mask boarder points with 'num_classes' and set its weight to 0
 source_dir = '/wholebrain/songbird/j0251/groundtruth/compartment_gt/2021_06_30_more_samples/hc_out_2021_06/'
 train_ds = CloudDataSemseg(npoints=npoints, transform=train_transform, use_subcell=use_subcell,
-                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=num_classes,
-                           source_dir=source_dir)
+                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=ignore_l,
+                           source_dir=source_dir, remap_dict=remap_dicts[gt_type])
 valid_ds = CloudDataSemseg(npoints=npoints, transform=valid_transform, train=False, use_subcell=use_subcell,
-                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=num_classes,
-                           source_dir=source_dir)
+                           batch_size=batch_size, ctx_size=ctx, mask_borders_with_id=ignore_l,
+                           source_dir=source_dir, remap_dict=remap_dicts[gt_type])
 
 # PREPARE AND START TRAINING #
@@ -177,8 +187,8 @@ lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
 #     gamma=0.99994,
 # )
 # set weight of the masking label at context boarders to 0
-class_weights = torch.tensor([1, 1, 1, 2, 8, 4, 8], dtype=torch.float32, device=device)
-criterion = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=num_classes).to(device)
+class_weights = torch.tensor(weights[gt_type], dtype=torch.float32, device=device)
+criterion = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_l).to(device)
 
 valid_metrics = {  # mean metrics
     'val_accuracy_mean': metrics.Accuracy(),
@@ -187,11 +197,11 @@ valid_metrics = {  # mean metrics
     'val_DSC_mean': metrics.DSC(),
     'val_IoU_mean': metrics.IoU(),
 }
-if num_classes > 2:
+if num_classes[gt_type] > 2:
     # Add separate per-class accuracy metrics only if there are more than 2 classes
     valid_metrics.update({
         f'val_IoU_c{i}': metrics.Accuracy(i)
-        for i in range(num_classes)
+        for i in range(num_classes[gt_type])
     })
 
 # Create trainer
@@ -211,11 +221,11 @@ trainer = Trainer3d(
     enable_save_trace=enable_save_trace,
     exp_name=name,
     schedulers={"lr": lr_sched},
-    num_classes=num_classes,
+    num_classes=num_classes[gt_type],
     # example_input=example_input,
     dataloader_kwargs=dict(collate_fn=lambda x: x[0]),
-    nbatch_avg=1,
-    tqdm_kwargs=dict(disable=True),
+    nbatch_avg=4,
+    tqdm_kwargs=dict(disable=False),
     lcp_flag=True
 )
 
diff --git a/syconn/extraction/cs_processing_steps.py b/syconn/extraction/cs_processing_steps.py
index c40829e5..483f24cf 100755
--- a/syconn/extraction/cs_processing_steps.py
+++ b/syconn/extraction/cs_processing_steps.py
@@ -323,7 +323,7 @@ def combine_and_split_syn(wd, cs_gap_nm=300, ssd_version=None, syn_version=None,
 
     log_extraction.debug(f'Filtering relevant synapses done.')
     storage_location_ids = get_unique_subfold_ixs(n_folders_fs)
-    n_used_paths = min(global_params.config.ncore_total * 4, len(storage_location_ids),
+    n_used_paths = min(global_params.config.ncore_total * 8, len(storage_location_ids),
                        len(rel_ssv_with_syn_ids))
    voxel_rel_paths = chunkify([subfold_from_ix(ix, n_folders_fs) for ix in storage_location_ids],
                                n_used_paths)
@@ -422,68 +422,70 @@ def _combine_and_split_syn_thread(args):
         voxel_list = np.concatenate(voxel_list)
 
         for this_cc in ccs:
-            this_cc_mask = np.array(list(this_cc))
-            # retrieve the index of the syn objects selected for this CC
-            this_syn_ixs, this_syn_ids_cnt = np.unique(synix_list[this_cc_mask],
-                                                       return_counts=True)
-            # the weight is important
-            this_agg_syn_weights = this_syn_ids_cnt / np.sum(this_syn_ids_cnt)
-            if np.sum(this_syn_ids_cnt) < cell_obj_cnf['min_obj_vx']['syn_ssv']:
-                continue
-            this_attr = syn_attr_list[this_syn_ixs]
-            this_vx = voxel_list[this_cc_mask]
-            syn_ssv = sd_syn_ssv.get_segmentation_object(syn_ssv_id)
-            if (os.path.abspath(syn_ssv.attr_dict_path)
-                    != os.path.abspath(base_dir + "/attr_dict.pkl")):
-                raise ValueError(f'Path mis-match!')
-            synssv_attr_dc = dict(neuron_partners=ssv_ids)
-            voxel_dc[syn_ssv_id] = this_vx
-            synssv_attr_dc["rep_coord"] = this_vx[len(this_vx) // 2]  # any rep coord
-            synssv_attr_dc["bounding_box"] = np.array([np.min(this_vx, axis=0), np.max(this_vx, axis=0)])
-            synssv_attr_dc["size"] = len(this_vx)
-            # calc_contact_syn_mesh returns a list with a single mesh (for syn_ssv)
-            if mesh_min_obj_vx < synssv_attr_dc["size"]:
-                syn_ssv._mesh = calc_contact_syn_mesh(syn_ssv, voxel_dc=voxel_dc, **syn_meshing_kws)[0]
-                mesh_dc[syn_ssv.id] = syn_ssv.mesh
-                synssv_attr_dc["mesh_bb"] = syn_ssv.mesh_bb
-                synssv_attr_dc["mesh_area"] = syn_ssv.mesh_area
-            else:
-                zero_mesh = [np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32),
-                             np.zeros((0,), dtype=np.float32)]
-                mesh_dc[syn_ssv.id] = zero_mesh
-                synssv_attr_dc["mesh_bb"] = synssv_attr_dc["bounding_box"] * scaling
-                synssv_attr_dc["mesh_area"] = 0
-            # aggregate syn properties
-            syn_props_agg = {}
-            # cs_id is the same as syn_id ('syn' are just a subset of 'cs')
-            for dc in this_attr:
-                for k in ['id_cs_ratio', 'cs_id', 'sym_prop', 'asym_prop']:
-                    syn_props_agg.setdefault(k, []).append(dc[k])
-            # rename and delete old entry
-            syn_props_agg['cs_ids'] = syn_props_agg['cs_id']
-            del syn_props_agg['cs_id']
-
-            # use the fraction of 'syn' voxels used for this connected component, i.e. 'this_agg_syn_weights', as weight
-            # agglomerate the syn-to-cs ratio as a weighted sum
-            syn_props_agg['id_cs_ratio'] = np.sum(this_agg_syn_weights * np.array(syn_props_agg['id_cs_ratio']))
-
-            # 'syn_ssv' synapse type as weighted sum of the 'syn' fragment types
-            sym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['sym_prop']))
-            asym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['asym_prop']))
-            syn_props_agg['sym_prop'] = sym_prop
-            syn_props_agg['asym_prop'] = asym_prop
-
-            if sym_prop + asym_prop == 0:
-                sym_ratio = -1
-            else:
-                sym_ratio = sym_prop / float(asym_prop + sym_prop)
-            syn_props_agg["syn_type_sym_ratio"] = sym_ratio
-            syn_sign = -1 if sym_ratio > cell_obj_cnf['sym_thresh'] else 1
-            syn_props_agg["syn_sign"] = syn_sign
-
-            # add syn_ssv dict to AttributeStorage
-            synssv_attr_dc.update(syn_props_agg)
-            attr_dc[syn_ssv_id] = synssv_attr_dc
+            # do not process synapse again if job has been restarted
+            if syn_ssv_id not in attr_dc:
+                this_cc_mask = np.array(list(this_cc))
+                # retrieve the index of the syn objects selected for this CC
+                this_syn_ixs, this_syn_ids_cnt = np.unique(synix_list[this_cc_mask],
+                                                           return_counts=True)
+                # the weight is important
+                this_agg_syn_weights = this_syn_ids_cnt / np.sum(this_syn_ids_cnt)
+                if np.sum(this_syn_ids_cnt) < cell_obj_cnf['min_obj_vx']['syn_ssv']:
+                    continue
+                this_attr = syn_attr_list[this_syn_ixs]
+                this_vx = voxel_list[this_cc_mask]
+                syn_ssv = sd_syn_ssv.get_segmentation_object(syn_ssv_id)
+                if (os.path.abspath(syn_ssv.attr_dict_path)
+                        != os.path.abspath(base_dir + "/attr_dict.pkl")):
+                    raise ValueError(f'Path mis-match!')
+                synssv_attr_dc = dict(neuron_partners=ssv_ids)
+                voxel_dc[syn_ssv_id] = this_vx
+                synssv_attr_dc["rep_coord"] = this_vx[len(this_vx) // 2]  # any rep coord
+                synssv_attr_dc["bounding_box"] = np.array([np.min(this_vx, axis=0), np.max(this_vx, axis=0)])
+                synssv_attr_dc["size"] = len(this_vx)
+                # calc_contact_syn_mesh returns a list with a single mesh (for syn_ssv)
+                if mesh_min_obj_vx < synssv_attr_dc["size"]:
+                    syn_ssv._mesh = calc_contact_syn_mesh(syn_ssv, voxel_dc=voxel_dc, **syn_meshing_kws)[0]
+                    mesh_dc[syn_ssv.id] = syn_ssv.mesh
+                    synssv_attr_dc["mesh_bb"] = syn_ssv.mesh_bb
+                    synssv_attr_dc["mesh_area"] = syn_ssv.mesh_area
+                else:
+                    zero_mesh = [np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32),
+                                 np.zeros((0,), dtype=np.float32)]
+                    mesh_dc[syn_ssv.id] = zero_mesh
+                    synssv_attr_dc["mesh_bb"] = synssv_attr_dc["bounding_box"] * scaling
+                    synssv_attr_dc["mesh_area"] = 0
+                # aggregate syn properties
+                syn_props_agg = {}
+                # cs_id is the same as syn_id ('syn' are just a subset of 'cs')
+                for dc in this_attr:
+                    for k in ['id_cs_ratio', 'cs_id', 'sym_prop', 'asym_prop']:
+                        syn_props_agg.setdefault(k, []).append(dc[k])
+                # rename and delete old entry
+                syn_props_agg['cs_ids'] = syn_props_agg['cs_id']
+                del syn_props_agg['cs_id']
+
+                # use the fraction of 'syn' voxels used for this connected component, i.e. 'this_agg_syn_weights', as weight
+                # agglomerate the syn-to-cs ratio as a weighted sum
+                syn_props_agg['id_cs_ratio'] = np.sum(this_agg_syn_weights * np.array(syn_props_agg['id_cs_ratio']))
+
+                # 'syn_ssv' synapse type as weighted sum of the 'syn' fragment types
+                sym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['sym_prop']))
+                asym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['asym_prop']))
+                syn_props_agg['sym_prop'] = sym_prop
+                syn_props_agg['asym_prop'] = asym_prop
+
+                if sym_prop + asym_prop == 0:
+                    sym_ratio = -1
+                else:
+                    sym_ratio = sym_prop / float(asym_prop + sym_prop)
+                syn_props_agg["syn_type_sym_ratio"] = sym_ratio
+                syn_sign = -1 if sym_ratio > cell_obj_cnf['sym_thresh'] else 1
+                syn_props_agg["syn_sign"] = syn_sign
+
+                # add syn_ssv dict to AttributeStorage
+                synssv_attr_dc.update(syn_props_agg)
+                attr_dc[syn_ssv_id] = synssv_attr_dc
             if use_new_subfold:
                 syn_ssv_id += np.uint(1)
                 if syn_ssv_id - base_id >= div_base:
-- 
GitLab
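
For reference, a minimal sketch (not part of the patch) of how the label remapping introduced above interacts with the ignore label: the `dnh` remap dict, class counts, weights, and loss setup are copied from `cnn_semseg_lcp.py`, while the example label array and the printed result are purely illustrative.

```python
import numpy as np
import torch

# Values taken from the patch ('dnh' ground truth type): three classes
# (dendrite, neck, head) and num_classes doubling as the ignore label.
num_classes = {'ads': 3, 'abt': 3, 'dnh': 3, 'fine': 7}
gt_type = 'dnh'
ignore_l = num_classes[gt_type]  # == 3
remap_dict = {1: ignore_l, 2: ignore_l, 3: ignore_l, 4: ignore_l, 5: 1, 6: 2}

# Hypothetical per-point labels in the 'fine' scheme:
# 0=dendrite, 1=axon, 2=soma, 3=bouton, 4=terminal, 5=neck, 6=head.
out_labels = np.array([0, 1, 2, 3, 4, 5, 6])

# Same in-place remapping loop as CloudDataSemseg.__getitem__ in the patch.
for k, v in remap_dict.items():
    out_labels[out_labels == k] = v
print(out_labels)  # [0 3 3 3 3 1 2]: dendrite stays 0, neck -> 1, head -> 2, rest ignored

# The ignore label is then excluded from the loss, as set up in cnn_semseg_lcp.py.
class_weights = torch.tensor([1., 2., 2.])  # weights['dnh'] from the patch
criterion = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_l)
```

Because the remapping runs in place and in dict order, a mapped-to value that also appears as a later key would be remapped a second time; the `dnh` dict above avoids this, which is worth keeping in mind when defining new remap dicts.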