From 88b0672fa8f3531c3f8b471e858336eb51a20ecb Mon Sep 17 00:00:00 2001 From: arother <rother@neuro.mpg.de> Date: Tue, 25 Apr 2023 17:22:16 +0200 Subject: [PATCH] updated str2int, int2str_converter with glia cell classes --- syconn/extraction/cs_processing_steps.py | 14 ++++++++++++++ syconn/handler/prediction.py | 9 +++++++++ 2 files changed, 23 insertions(+) diff --git a/syconn/extraction/cs_processing_steps.py b/syconn/extraction/cs_processing_steps.py index 6108f789..0f202f01 100755 --- a/syconn/extraction/cs_processing_steps.py +++ b/syconn/extraction/cs_processing_steps.py @@ -1338,17 +1338,23 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st df = pandas.read_excel(path2file, header=0, names=[ 'ixs', 'coord', 'pre', 'post', 'syn', 'doublechecked', 'triplechecked', '?', 'comments']).values df = df[:, :7] + synaptic = 0 + non_synaptic = 0 for ix in range(df.shape[0]): c_orig = df[ix, 5] c = df[ix, 6] if type(c) != float and 'yes' in c: unified_comment = 'synaptic' + synaptic += 1 elif type(c) != float and 'no' in c: unified_comment = 'non-synaptic' + non_synaptic += 1 elif 'yes' in c_orig: unified_comment = 'synaptic' + synaptic += 1 elif 'no' in c_orig: unified_comment = 'non-synaptic' + non_synaptic += 1 else: log.warn(f'Did not understand GT comment "{c}". Skipping') continue @@ -1357,6 +1363,7 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st labels = np.array(labels) label_coords = np.array(label_coords) + log.info(f'Before filtering: {synaptic} synaptic labels and {non_synaptic} non_synaptic labels') # get deterministic order by sorting by coordinate first and then seeded shuffling ixs = [i[0] for i in sorted(enumerate(label_coords), @@ -1389,8 +1396,15 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st if np.sum(mask) == 0: raise ValueError synssv_ids = synssv_ids[mask] + not_mapped_labels = labels[mask == 0] + not_mapped_syn = not_mapped_labels[not_mapped_labels == 'synaptic'] + not_mapped_nonsyn = not_mapped_labels[not_mapped_labels == 'non-synaptic'] labels = labels[mask] + mapped_syn = labels[labels == 'synaptic'] + mapped_nonsyn = labels[labels == 'non-synaptic'] log.info(f'Found {np.sum(mask)}/{len(mask)} samples with a distance < {max_dist_vx} vx to the target.') + log.info(f'Excluding {len(not_mapped_syn)} synaptic labels and {len(not_mapped_nonsyn)} non_synaptic labels after filtering') + log.info(f'Training with {len(mapped_syn)} synaptic labels and {len(mapped_nonsyn)} non_synaptic labels after filtering') log.info(f'Synapse features will now be generated.') features = [] diff --git a/syconn/handler/prediction.py b/syconn/handler/prediction.py index caaf35b8..761081e4 100755 --- a/syconn/handler/prediction.py +++ b/syconn/handler/prediction.py @@ -1324,6 +1324,10 @@ def str2int_converter(comment: str, gt_type: str) -> int: str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7, FS=8, LTS=9, NGF=10) return str2int_label[comment] + elif gt_type == 'ctgt_j0251_v3': + str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7, + FS=8, LTS=9, NGF=10, ASTRO=11, OLIGO=12, MICRO=13, FRAG=14) + return str2int_label[comment] else: raise ValueError("Given groundtruth type is not valid.") @@ -1400,5 +1404,10 @@ def int2str_converter(label: int, gt_type: str) -> str: FS=8, LTS=9, NGF=10) int2str_label = {v: k for k, v in str2int_label.items()} return int2str_label[label] + elif gt_type == 'ctgt_j0251_v3': + str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7, + FS=8, LTS=9, NGF=10, ASTRO=11, OLIGO=12, MICRO=13, FRAG=14) + int2str_label = {v: k for k, v in str2int_label.items()} + return int2str_label[label] else: raise ValueError("Given ground truth type is not valid.") -- GitLab