Skip to content
Snippets Groups Projects
Commit 88b0672f authored by Alexandra Rother's avatar Alexandra Rother
Browse files

updated str2int, int2str_converter with glia cell classes

parent ccc69590
Branches
Tags
No related merge requests found
Pipeline #164638 failed
...@@ -1338,17 +1338,23 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st ...@@ -1338,17 +1338,23 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st
df = pandas.read_excel(path2file, header=0, names=[ df = pandas.read_excel(path2file, header=0, names=[
'ixs', 'coord', 'pre', 'post', 'syn', 'doublechecked', 'triplechecked', '?', 'comments']).values 'ixs', 'coord', 'pre', 'post', 'syn', 'doublechecked', 'triplechecked', '?', 'comments']).values
df = df[:, :7] df = df[:, :7]
synaptic = 0
non_synaptic = 0
for ix in range(df.shape[0]): for ix in range(df.shape[0]):
c_orig = df[ix, 5] c_orig = df[ix, 5]
c = df[ix, 6] c = df[ix, 6]
if type(c) != float and 'yes' in c: if type(c) != float and 'yes' in c:
unified_comment = 'synaptic' unified_comment = 'synaptic'
synaptic += 1
elif type(c) != float and 'no' in c: elif type(c) != float and 'no' in c:
unified_comment = 'non-synaptic' unified_comment = 'non-synaptic'
non_synaptic += 1
elif 'yes' in c_orig: elif 'yes' in c_orig:
unified_comment = 'synaptic' unified_comment = 'synaptic'
synaptic += 1
elif 'no' in c_orig: elif 'no' in c_orig:
unified_comment = 'non-synaptic' unified_comment = 'non-synaptic'
non_synaptic += 1
else: else:
log.warn(f'Did not understand GT comment "{c}". Skipping') log.warn(f'Did not understand GT comment "{c}". Skipping')
continue continue
...@@ -1357,6 +1363,7 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st ...@@ -1357,6 +1363,7 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st
labels = np.array(labels) labels = np.array(labels)
label_coords = np.array(label_coords) label_coords = np.array(label_coords)
log.info(f'Before filtering: {synaptic} synaptic labels and {non_synaptic} non_synaptic labels')
# get deterministic order by sorting by coordinate first and then seeded shuffling # get deterministic order by sorting by coordinate first and then seeded shuffling
ixs = [i[0] for i in sorted(enumerate(label_coords), ixs = [i[0] for i in sorted(enumerate(label_coords),
...@@ -1389,8 +1396,15 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st ...@@ -1389,8 +1396,15 @@ def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: st
if np.sum(mask) == 0: if np.sum(mask) == 0:
raise ValueError raise ValueError
synssv_ids = synssv_ids[mask] synssv_ids = synssv_ids[mask]
not_mapped_labels = labels[mask == 0]
not_mapped_syn = not_mapped_labels[not_mapped_labels == 'synaptic']
not_mapped_nonsyn = not_mapped_labels[not_mapped_labels == 'non-synaptic']
labels = labels[mask] labels = labels[mask]
mapped_syn = labels[labels == 'synaptic']
mapped_nonsyn = labels[labels == 'non-synaptic']
log.info(f'Found {np.sum(mask)}/{len(mask)} samples with a distance < {max_dist_vx} vx to the target.') log.info(f'Found {np.sum(mask)}/{len(mask)} samples with a distance < {max_dist_vx} vx to the target.')
log.info(f'Excluding {len(not_mapped_syn)} synaptic labels and {len(not_mapped_nonsyn)} non_synaptic labels after filtering')
log.info(f'Training with {len(mapped_syn)} synaptic labels and {len(mapped_nonsyn)} non_synaptic labels after filtering')
log.info(f'Synapse features will now be generated.') log.info(f'Synapse features will now be generated.')
features = [] features = []
......
...@@ -1324,6 +1324,10 @@ def str2int_converter(comment: str, gt_type: str) -> int: ...@@ -1324,6 +1324,10 @@ def str2int_converter(comment: str, gt_type: str) -> int:
str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7, str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
FS=8, LTS=9, NGF=10) FS=8, LTS=9, NGF=10)
return str2int_label[comment] return str2int_label[comment]
elif gt_type == 'ctgt_j0251_v3':
str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
FS=8, LTS=9, NGF=10, ASTRO=11, OLIGO=12, MICRO=13, FRAG=14)
return str2int_label[comment]
else: else:
raise ValueError("Given groundtruth type is not valid.") raise ValueError("Given groundtruth type is not valid.")
...@@ -1400,5 +1404,10 @@ def int2str_converter(label: int, gt_type: str) -> str: ...@@ -1400,5 +1404,10 @@ def int2str_converter(label: int, gt_type: str) -> str:
FS=8, LTS=9, NGF=10) FS=8, LTS=9, NGF=10)
int2str_label = {v: k for k, v in str2int_label.items()} int2str_label = {v: k for k, v in str2int_label.items()}
return int2str_label[label] return int2str_label[label]
elif gt_type == 'ctgt_j0251_v3':
str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
FS=8, LTS=9, NGF=10, ASTRO=11, OLIGO=12, MICRO=13, FRAG=14)
int2str_label = {v: k for k, v in str2int_label.items()}
return int2str_label[label]
else: else:
raise ValueError("Given ground truth type is not valid.") raise ValueError("Given ground truth type is not valid.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment